diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 8e7db2c4ec..f03807c8f9 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -12,17 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Collection, List, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.api.errors
from synapse.api.constants import EduTypes
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class DeviceStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def add_device_change(self, user_id, device_ids, host):
+ def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to
`device_lists_outbound_pokes` table.
"""
@@ -44,12 +51,13 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
)
- def test_store_new_device(self):
+ def test_store_new_device(self) -> None:
self.get_success(
self.store.store_device("user_id", "device_id", "display_name")
)
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -59,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res,
)
- def test_get_devices_by_user(self):
+ def test_get_devices_by_user(self) -> None:
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
@@ -89,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res["device2"],
)
- def test_count_devices_by_users(self):
+ def test_count_devices_by_users(self) -> None:
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
@@ -114,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
self.assertEqual(3, res)
- def test_get_device_updates_by_remote(self):
+ def test_get_device_updates_by_remote(self) -> None:
device_ids = ["device_id1", "device_id2"]
# Add two device updates with sequential `stream_id`s
@@ -128,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
- def test_get_device_updates_by_remote_can_limit_properly(self):
+ def test_get_device_updates_by_remote_can_limit_properly(self) -> None:
"""
Tests that `get_device_updates_by_remote` returns an appropriate
stream_id to resume fetching from (without skipping any results).
@@ -280,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
self.assertEqual(device_updates, [])
- def _check_devices_in_updates(self, expected_device_ids, device_updates):
+ def _check_devices_in_updates(
+ self,
+ expected_device_ids: Collection[str],
+ device_updates: List[Tuple[str, JsonDict]],
+ ) -> None:
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
@@ -289,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase):
}
self.assertEqual(received_device_ids, set(expected_device_ids))
- def test_update_device(self):
+ def test_update_device(self) -> None:
self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1")
)
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
self.get_success(self.store.update_device("user_id", "device_id"))
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 1", res["display_name"])
# do the update
@@ -311,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase):
# check it worked
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 2", res["display_name"])
- def test_update_unknown_device(self):
+ def test_update_unknown_device(self) -> None:
exc = self.get_failure(
self.store.update_device(
"user_id", "unknown_device_id", new_display_name="display_name 2"
|