summary refs log tree commit diff
path: root/tests/storage/test_devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_devices.py')
-rw-r--r--tests/storage/test_devices.py35
1 files changed, 25 insertions, 10 deletions
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 8e7db2c4ec..f03807c8f9 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -12,17 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Collection, List, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.api.errors
 from synapse.api.constants import EduTypes
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
 
 class DeviceStoreTestCase(HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
-    def add_device_change(self, user_id, device_ids, host):
+    def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
         """Add a device list change for the given device to
         `device_lists_outbound_pokes` table.
         """
@@ -44,12 +51,13 @@ class DeviceStoreTestCase(HomeserverTestCase):
                 )
             )
 
-    def test_store_new_device(self):
+    def test_store_new_device(self) -> None:
         self.get_success(
             self.store.store_device("user_id", "device_id", "display_name")
         )
 
         res = self.get_success(self.store.get_device("user_id", "device_id"))
+        assert res is not None
         self.assertDictContainsSubset(
             {
                 "user_id": "user_id",
@@ -59,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
             res,
         )
 
-    def test_get_devices_by_user(self):
+    def test_get_devices_by_user(self) -> None:
         self.get_success(
             self.store.store_device("user_id", "device1", "display_name 1")
         )
@@ -89,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
             res["device2"],
         )
 
-    def test_count_devices_by_users(self):
+    def test_count_devices_by_users(self) -> None:
         self.get_success(
             self.store.store_device("user_id", "device1", "display_name 1")
         )
@@ -114,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
         )
         self.assertEqual(3, res)
 
-    def test_get_device_updates_by_remote(self):
+    def test_get_device_updates_by_remote(self) -> None:
         device_ids = ["device_id1", "device_id2"]
 
         # Add two device updates with sequential `stream_id`s
@@ -128,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
         # Check original device_ids are contained within these updates
         self._check_devices_in_updates(device_ids, device_updates)
 
-    def test_get_device_updates_by_remote_can_limit_properly(self):
+    def test_get_device_updates_by_remote_can_limit_properly(self) -> None:
         """
         Tests that `get_device_updates_by_remote` returns an appropriate
         stream_id to resume fetching from (without skipping any results).
@@ -280,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase):
         )
         self.assertEqual(device_updates, [])
 
-    def _check_devices_in_updates(self, expected_device_ids, device_updates):
+    def _check_devices_in_updates(
+        self,
+        expected_device_ids: Collection[str],
+        device_updates: List[Tuple[str, JsonDict]],
+    ) -> None:
         """Check that an specific device ids exist in a list of device update EDUs"""
         self.assertEqual(len(device_updates), len(expected_device_ids))
 
@@ -289,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase):
         }
         self.assertEqual(received_device_ids, set(expected_device_ids))
 
-    def test_update_device(self):
+    def test_update_device(self) -> None:
         self.get_success(
             self.store.store_device("user_id", "device_id", "display_name 1")
         )
 
         res = self.get_success(self.store.get_device("user_id", "device_id"))
+        assert res is not None
         self.assertEqual("display_name 1", res["display_name"])
 
         # do a no-op first
         self.get_success(self.store.update_device("user_id", "device_id"))
         res = self.get_success(self.store.get_device("user_id", "device_id"))
+        assert res is not None
         self.assertEqual("display_name 1", res["display_name"])
 
         # do the update
@@ -311,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         # check it worked
         res = self.get_success(self.store.get_device("user_id", "device_id"))
+        assert res is not None
         self.assertEqual("display_name 2", res["display_name"])
 
-    def test_update_unknown_device(self):
+    def test_update_unknown_device(self) -> None:
         exc = self.get_failure(
             self.store.update_device(
                 "user_id", "unknown_device_id", new_display_name="display_name 2"