diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index dabc1c5f09..ef4cf8d0f1 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,32 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
import synapse.api.errors
-import tests.unittest
-import tests.utils
-
-
-class DeviceStoreTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
+from tests.unittest import HomeserverTestCase
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class DeviceStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_store_new_device(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device_id", "display_name")
)
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res,
)
- @defer.inlineCallbacks
def test_get_devices_by_user(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
- res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
+ res = self.get_success(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset(
{
@@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res["device2"],
)
- @defer.inlineCallbacks
def test_count_devices_by_users(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+ res = self.get_success(self.store.count_devices_by_users())
self.assertEqual(0, res)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+ res = self.get_success(self.store.count_devices_by_users(["unknown"]))
self.assertEqual(0, res)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+ res = self.get_success(self.store.count_devices_by_users(["user_id"]))
self.assertEqual(2, res)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.count_devices_by_users(["user_id", "user_id2"])
)
self.assertEqual(3, res)
- @defer.inlineCallbacks
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield defer.ensureDeferred(
+ now_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=100)
)
@@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
}
self.assertEqual(received_device_ids, set(expected_device_ids))
- @defer.inlineCallbacks
def test_update_device(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1")
)
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
- yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ self.get_success(self.store.update_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
- yield defer.ensureDeferred(
+ self.get_success(
self.store.update_device(
"user_id", "device_id", new_display_name="display_name 2"
)
)
# check it worked
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
- @defer.inlineCallbacks
def test_update_unknown_device(self):
- with self.assertRaises(synapse.api.errors.StoreError) as cm:
- yield defer.ensureDeferred(
- self.store.update_device(
- "user_id", "unknown_device_id", new_display_name="display_name 2"
- )
- )
- self.assertEqual(404, cm.exception.code)
+ exc = self.get_failure(
+ self.store.update_device(
+ "user_id", "unknown_device_id", new_display_name="display_name 2"
+ ),
+ synapse.api.errors.StoreError,
+ )
+ self.assertEqual(404, exc.value.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index da93ca3980..0db233fd68 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,28 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.types import RoomAlias, RoomID
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase
-class DirectoryStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
-
+class DirectoryStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
- @defer.inlineCallbacks
def test_room_to_alias(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals(
["#my-room:test"],
- (
- yield defer.ensureDeferred(
- self.store.get_aliases_for_room(self.room.to_string())
- )
- ),
+ (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_alias_to_room(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]},
- (
- yield defer.ensureDeferred(
- self.store.get_association_from_room_alias(self.alias)
- )
- ),
+ (self.get_success(self.store.get_association_from_room_alias(self.alias))),
)
- @defer.inlineCallbacks
def test_delete_alias(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
)
- room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
+ room_id = self.get_success(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_association_from_room_alias(self.alias)
- )
- )
+ (self.get_success(self.store.get_association_from_room_alias(self.alias)))
)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 3fc4bb13b6..1e54b940fd 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,30 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from tests.unittest import HomeserverTestCase
-import tests.unittest
-import tests.utils
-
-class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EndToEndKeyStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_key_without_device_name(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+ self.get_success(self.store.store_device("user", "device", None))
- yield defer.ensureDeferred(
- self.store.set_e2e_device_keys("user", "device", now, json)
- )
+ self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev)
- @defer.inlineCallbacks
def test_reupload_key(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+ self.get_success(self.store.store_device("user", "device", None))
- changed = yield defer.ensureDeferred(
+ changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing
# changed
- changed = yield defer.ensureDeferred(
+ changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertFalse(changed)
- @defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(
- self.store.set_e2e_device_keys("user", "device", now, json)
- )
- yield defer.ensureDeferred(
- self.store.store_device("user", "device", "display_name")
- )
+ self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
+ self.get_success(self.store.store_device("user", "device", "display_name"))
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
)
- @defer.inlineCallbacks
def test_multiple_devices(self):
now = 1470174257070
- yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
- yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
- yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
- yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
+ self.get_success(self.store.store_device("user1", "device1", None))
+ self.get_success(self.store.store_device("user1", "device2", None))
+ self.get_success(self.store.store_device("user2", "device1", None))
+ self.get_success(self.store.store_device("user2", "device2", None))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api(
(("user1", "device1"), ("user2", "device2"))
)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 485f1ee033..239f7c9faf 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +15,7 @@
from mock import Mock
-from twisted.internet import defer
-
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
USER_ID = "@user:example.com"
@@ -30,37 +27,31 @@ HIGHLIGHT = [
]
-class EventPushActionsStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventPushActionsStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.persist_events_store = hs.get_datastores().persist_events
- @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20
)
)
- @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20
)
)
- @defer.inlineCallbacks
def test_count_aggregation(self):
room_id = "!foo:example.com"
user_id = "@user1235:example.com"
- @defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield defer.ensureDeferred(
+ counts = self.get_success(
self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
@@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
},
)
- @defer.inlineCallbacks
def _inject_actions(stream, action):
event = Mock()
event.room_id = room_id
@@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_push_actions_to_staging(
event.event_id,
{user_id: action},
False,
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"",
self.persist_events_store._set_push_actions_for_event_and_users_txn,
@@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
def _rotate(stream):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
)
def _mark_read(stream, depth):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
@@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
)
- yield _assert_counts(0, 0)
- yield _inject_actions(1, PlAIN_NOTIF)
- yield _assert_counts(1, 0)
- yield _rotate(2)
- yield _assert_counts(1, 0)
+ _assert_counts(0, 0)
+ _inject_actions(1, PlAIN_NOTIF)
+ _assert_counts(1, 0)
+ _rotate(2)
+ _assert_counts(1, 0)
- yield _inject_actions(3, PlAIN_NOTIF)
- yield _assert_counts(2, 0)
- yield _rotate(4)
- yield _assert_counts(2, 0)
+ _inject_actions(3, PlAIN_NOTIF)
+ _assert_counts(2, 0)
+ _rotate(4)
+ _assert_counts(2, 0)
- yield _inject_actions(5, PlAIN_NOTIF)
- yield _mark_read(3, 3)
- yield _assert_counts(1, 0)
+ _inject_actions(5, PlAIN_NOTIF)
+ _mark_read(3, 3)
+ _assert_counts(1, 0)
- yield _mark_read(5, 5)
- yield _assert_counts(0, 0)
+ _mark_read(5, 5)
+ _assert_counts(0, 0)
- yield _inject_actions(6, PlAIN_NOTIF)
- yield _rotate(7)
+ _inject_actions(6, PlAIN_NOTIF)
+ _rotate(7)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
)
- yield _assert_counts(1, 0)
+ _assert_counts(1, 0)
- yield _mark_read(7, 7)
- yield _assert_counts(0, 0)
+ _mark_read(7, 7)
+ _assert_counts(0, 0)
- yield _inject_actions(8, HIGHLIGHT)
- yield _assert_counts(1, 1)
- yield _rotate(9)
- yield _assert_counts(1, 1)
- yield _rotate(10)
- yield _assert_counts(1, 1)
+ _inject_actions(8, HIGHLIGHT)
+ _assert_counts(1, 1)
+ _rotate(9)
+ _assert_counts(1, 1)
+ _rotate(10)
+ _assert_counts(1, 1)
- @defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.simple_insert(
"events",
{
@@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
# start with the base case where there are no events in the table
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(11)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 0)
# now with one event
- yield add_event(2, 10)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(9)
- )
+ add_event(2, 10)
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(9))
self.assertEqual(r, 2)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(10)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(10))
self.assertEqual(r, 2)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(11)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 3)
# add a bunch of dummy events to the events table
@@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
(10, 130),
(20, 140),
):
- yield add_event(stream_ordering, ts)
+ add_event(stream_ordering, ts)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(110)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(110))
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(120)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(120))
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(129)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(129))
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(140)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(140))
self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(160)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(160))
self.assertEqual(r, 21)
# check we can find an event at ordering zero
- yield add_event(0, 5)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(1)
- )
+ add_event(0, 5)
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(1))
self.assertEqual(r, 0)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index ea63bd56b4..d18ceb41a9 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,59 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.types import UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
-
-class ProfileStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class ProfileStoreTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
- @defer.inlineCallbacks
def test_displayname(self):
- yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+ self.get_success(self.store.create_profile(self.u_frank.localpart))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)
self.assertEquals(
"Frank",
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
),
)
# test set to None
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, None)
)
self.assertIsNone(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
)
)
- @defer.inlineCallbacks
def test_avatar_url(self):
- yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+ self.get_success(self.store.create_profile(self.u_frank.localpart))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
)
@@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase):
self.assertEquals(
"http://my.site/here",
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
),
)
# test set to None
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_avatar_url(self.u_frank.localpart, None)
)
self.assertIsNone(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b2a0e60856..2622207639 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +15,6 @@
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
@@ -230,10 +227,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._base_builder = base_builder
self._event_id = event_id
- @defer.inlineCallbacks
- def build(self, prev_event_ids, auth_event_ids):
- built_event = yield defer.ensureDeferred(
- self._base_builder.build(prev_event_ids, auth_event_ids)
+ async def build(self, prev_event_ids, auth_event_ids):
+ built_event = await self._base_builder.build(
+ prev_event_ids, auth_event_ids
)
built_event._event_id = self._event_id
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4eb41c46e8..c82cf15bc2 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,21 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
-class RegistrationStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class RegistrationStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.user_id = "@my-user:test"
@@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg"
- @defer.inlineCallbacks
def test_register(self):
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals(
{
@@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None,
"consent_server_notice_sent": None,
"appservice_id": None,
- "creation_ts": 1000,
+ "creation_ts": 0,
"user_type": None,
"deactivated": 0,
"shadow_banned": 0,
},
- (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
+ (self.get_success(self.store.get_user_by_id(self.user_id))),
)
- @defer.inlineCallbacks
def test_add_tokens(self):
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
- yield defer.ensureDeferred(
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
- result = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[1])
- )
+ result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertEqual(result.user_id, self.user_id)
self.assertEqual(result.device_id, self.device_id)
self.assertIsNotNone(result.token_id)
- @defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
- yield defer.ensureDeferred(
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
# now delete some
- yield defer.ensureDeferred(
+ self.get_success(
self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
)
# check they were deleted
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[1])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[0])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertEqual(self.user_id, user.user_id)
# now delete the rest
- yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
+ self.get_success(self.store.user_delete_access_tokens(self.user_id))
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[0])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertIsNone(user, "access token was not deleted without device_id")
- @defer.inlineCallbacks
def test_is_support_user(self):
TEST_USER = "@test:test"
SUPPORT_USER = "@support:test"
- res = yield defer.ensureDeferred(self.store.is_support_user(None))
+ res = self.get_success(self.store.is_support_user(None))
self.assertFalse(res)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.register_user(user_id=TEST_USER, password_hash=None)
)
- res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
+ res = self.get_success(self.store.is_support_user(TEST_USER))
self.assertFalse(res)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.register_user(
user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
)
)
- res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
+ res = self.get_success(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
- @defer.inlineCallbacks
def test_3pid_inhibit_invalid_validation_session_error(self):
"""Tests that enabling the configuration option to inhibit 3PID errors on
/requestToken also inhibits validation errors caused by an unknown session ID.
@@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
# Check that, with the config setting set to false (the default value), a
# validation error is caused by the unknown session ID.
- try:
- yield defer.ensureDeferred(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- )
- )
- except ThreepidValidationError as e:
- self.assertEquals(e.msg, "Unknown session_id", e)
+ e = self.get_failure(
+ self.store.validate_threepid_session(
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
+ ),
+ ThreepidValidationError,
+ )
+ self.assertEquals(e.value.msg, "Unknown session_id", e)
# Set the config setting to true.
self.store._ignore_unknown_session_error = True
# Check that now the validation error is caused by the token not matching.
- try:
- yield defer.ensureDeferred(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- )
- )
- except ThreepidValidationError as e:
- self.assertEquals(e.msg, "Validation token not found or has expired", e)
+ e = self.get_failure(
+ self.store.validate_threepid_session(
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
+ ),
+ ThreepidValidationError,
+ )
+ self.assertEquals(e.value.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index bc8400f240..0089d33c93 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,22 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomAlias, RoomID, UserID
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
-class RoomStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class RoomStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastore()
@@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(),
@@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase):
)
)
- @defer.inlineCallbacks
def test_get_room(self):
self.assertDictContainsSubset(
{
@@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
+ (self.get_success(self.store.get_room(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_get_room_unknown_room(self):
- self.assertIsNone(
- (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
- )
+ self.assertIsNone((self.get_success(self.store.get_room("!uknown:test"))))
- @defer.inlineCallbacks
def test_get_room_with_stats(self):
self.assertDictContainsSubset(
{
@@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
- (
- yield defer.ensureDeferred(
- self.store.get_room_with_stats(self.room.to_string())
- )
- ),
+ (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_room_with_stats("!uknown:test")
- )
- ),
+ (self.get_success(self.store.get_room_with_stats("!uknown:test"))),
)
-class RoomEventsStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = setup_test_homeserver(self.addCleanup)
-
+class RoomEventsStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
@@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
)
)
- @defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield defer.ensureDeferred(
+ self.get_success(
self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
- @defer.inlineCallbacks
def STALE_test_room_name(self):
name = "A-Room-Name"
- yield self.inject_room_event(
+ self.inject_room_event(
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)
@@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
state[0],
)
- @defer.inlineCallbacks
def STALE_test_room_topic(self):
topic = "A place for things"
- yield self.inject_room_event(
+ self.inject_room_event(
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 2471f1267d..f06b452fa9 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,24 +15,18 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__)
-class StateStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
-
+class StateStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_datastore = self.storage.state.stores.state
@@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
)
- @defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
},
)
- event, context = yield defer.ensureDeferred(
+ event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- yield defer.ensureDeferred(
- self.storage.persistence.persist_event(event, context)
- )
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
- @defer.inlineCallbacks
def test_get_state_groups_ids(self):
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield defer.ensureDeferred(
+ state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)
- @defer.inlineCallbacks
def test_get_state_groups(self):
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield defer.ensureDeferred(
+ state_group_map = self.get_success(
self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
- @defer.inlineCallbacks
def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- e3 = yield self.inject_state_event(
+ e3 = self.inject_state_event(
self.room,
self.u_alice,
EventTypes.Member,
self.u_alice.to_string(),
{"membership": Membership.JOIN},
)
- e4 = yield self.inject_state_event(
+ e4 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
self.u_bob.to_string(),
{"membership": Membership.JOIN},
)
- e5 = yield self.inject_state_event(
+ e5 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
@@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield defer.ensureDeferred(
- self.storage.state.get_state_for_event(e5.event_id)
- )
+ state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
self.assertIsNotNone(e4)
@@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
@@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
@@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield defer.ensureDeferred(
+ group_ids = self.get_success(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index a6f63f4aaf..019c5b7b14 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,10 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase, override_config
ALICE = "@alice:a"
BOB = "@bob:b"
@@ -25,73 +22,52 @@ BOBBY = "@bobby:a"
BELA = "@somenickname:a"
-class UserDirectoryStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = self.hs.get_datastore()
+class UserDirectoryStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(ALICE, "alice", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BOB, "bob", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BELA, "Bela", None)
- )
- yield defer.ensureDeferred(
- self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
- )
+ self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None))
+ self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None))
+ self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None))
+ self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
+ self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
- @defer.inlineCallbacks
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
)
- @defer.inlineCallbacks
+ @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_all_users(self):
- self.hs.config.user_directory_search_all_users = True
- try:
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
- self.assertFalse(r["limited"])
- self.assertEqual(2, len(r["results"]))
- self.assertDictEqual(
- r["results"][0],
- {"user_id": BOB, "display_name": "bob", "avatar_url": None},
- )
- self.assertDictEqual(
- r["results"][1],
- {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
- )
- finally:
- self.hs.config.user_directory_search_all_users = False
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(2, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BOB, "display_name": "bob", "avatar_url": None},
+ )
+ self.assertDictEqual(
+ r["results"][1],
+ {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
+ )
- @defer.inlineCallbacks
+ @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_stop_words(self):
"""Tests that a user can look up another user by searching for the start if its
display name even if that name happens to be a common English word that would
usually be ignored in full text searches.
"""
- self.hs.config.user_directory_search_all_users = True
- try:
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
- self.assertFalse(r["limited"])
- self.assertEqual(1, len(r["results"]))
- self.assertDictEqual(
- r["results"][0],
- {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
- )
- finally:
- self.hs.config.user_directory_search_all_users = False
+ r = self.get_success(self.store.search_user_dir(ALICE, "be", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
|