summary refs log tree commit diff
path: root/tests/storage/test_devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_devices.py')
-rw-r--r--tests/storage/test_devices.py80
1 files changed, 32 insertions, 48 deletions
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index dabc1c5f09..ef4cf8d0f1 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,32 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 import synapse.api.errors
 
-import tests.unittest
-import tests.utils
-
-
-class DeviceStoreTestCase(tests.unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.store = None  # type: synapse.storage.DataStore
+from tests.unittest import HomeserverTestCase
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
 
+class DeviceStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_store_new_device(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device_id", "display_name")
         )
 
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertDictContainsSubset(
             {
                 "user_id": "user_id",
@@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
             res,
         )
 
-    @defer.inlineCallbacks
     def test_get_devices_by_user(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device1", "display_name 1")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device2", "display_name 2")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id2", "device3", "display_name 3")
         )
 
-        res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
+        res = self.get_success(self.store.get_devices_by_user("user_id"))
         self.assertEqual(2, len(res.keys()))
         self.assertDictContainsSubset(
             {
@@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
             res["device2"],
         )
 
-    @defer.inlineCallbacks
     def test_count_devices_by_users(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device1", "display_name 1")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device2", "display_name 2")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id2", "device3", "display_name 3")
         )
 
-        res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+        res = self.get_success(self.store.count_devices_by_users())
         self.assertEqual(0, res)
 
-        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+        res = self.get_success(self.store.count_devices_by_users(["unknown"]))
         self.assertEqual(0, res)
 
-        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+        res = self.get_success(self.store.count_devices_by_users(["user_id"]))
         self.assertEqual(2, res)
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.count_devices_by_users(["user_id", "user_id2"])
         )
         self.assertEqual(3, res)
 
-    @defer.inlineCallbacks
     def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]
 
         # Add two device updates with a single stream_id
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
         )
 
         # Get all device updates ever meant for this remote
-        now_stream_id, device_updates = yield defer.ensureDeferred(
+        now_stream_id, device_updates = self.get_success(
             self.store.get_device_updates_by_remote("somehost", -1, limit=100)
         )
 
@@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         }
         self.assertEqual(received_device_ids, set(expected_device_ids))
 
-    @defer.inlineCallbacks
     def test_update_device(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device_id", "display_name 1")
         )
 
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do a no-op first
-        yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        self.get_success(self.store.update_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do the update
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.update_device(
                 "user_id", "device_id", new_display_name="display_name 2"
             )
         )
 
         # check it worked
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 2", res["display_name"])
 
-    @defer.inlineCallbacks
     def test_update_unknown_device(self):
-        with self.assertRaises(synapse.api.errors.StoreError) as cm:
-            yield defer.ensureDeferred(
-                self.store.update_device(
-                    "user_id", "unknown_device_id", new_display_name="display_name 2"
-                )
-            )
-        self.assertEqual(404, cm.exception.code)
+        exc = self.get_failure(
+            self.store.update_device(
+                "user_id", "unknown_device_id", new_display_name="display_name 2"
+            ),
+            synapse.api.errors.StoreError,
+        )
+        self.assertEqual(404, exc.value.code)