diff options
Diffstat (limited to 'tests/storage/test_devices.py')
-rw-r--r-- | tests/storage/test_devices.py | 35 |
1 files changed, 25 insertions, 10 deletions
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 8e7db2c4ec..f03807c8f9 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -12,17 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Collection, List, Tuple + +from twisted.test.proto_helpers import MemoryReactor + import synapse.api.errors from synapse.api.constants import EduTypes +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase class DeviceStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - def add_device_change(self, user_id, device_ids, host): + def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None: """Add a device list change for the given device to `device_lists_outbound_pokes` table. """ @@ -44,12 +51,13 @@ class DeviceStoreTestCase(HomeserverTestCase): ) ) - def test_store_new_device(self): + def test_store_new_device(self) -> None: self.get_success( self.store.store_device("user_id", "device_id", "display_name") ) res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertDictContainsSubset( { "user_id": "user_id", @@ -59,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase): res, ) - def test_get_devices_by_user(self): + def test_get_devices_by_user(self) -> None: self.get_success( self.store.store_device("user_id", "device1", "display_name 1") ) @@ -89,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase): res["device2"], ) - def test_count_devices_by_users(self): + def test_count_devices_by_users(self) -> None: self.get_success( self.store.store_device("user_id", "device1", "display_name 1") ) @@ -114,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase): ) self.assertEqual(3, res) - def test_get_device_updates_by_remote(self): + def test_get_device_updates_by_remote(self) -> None: device_ids = ["device_id1", "device_id2"] # Add two device updates with sequential `stream_id`s @@ -128,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase): # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) - def test_get_device_updates_by_remote_can_limit_properly(self): + def test_get_device_updates_by_remote_can_limit_properly(self) -> None: """ Tests that `get_device_updates_by_remote` returns an appropriate stream_id to resume fetching from (without skipping any results). @@ -280,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase): ) self.assertEqual(device_updates, []) - def _check_devices_in_updates(self, expected_device_ids, device_updates): + def _check_devices_in_updates( + self, + expected_device_ids: Collection[str], + device_updates: List[Tuple[str, JsonDict]], + ) -> None: """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) @@ -289,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase): } self.assertEqual(received_device_ids, set(expected_device_ids)) - def test_update_device(self): + def test_update_device(self) -> None: self.get_success( self.store.store_device("user_id", "device_id", "display_name 1") ) res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertEqual("display_name 1", res["display_name"]) # do a no-op first self.get_success(self.store.update_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertEqual("display_name 1", res["display_name"]) # do the update @@ -311,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase): # check it worked res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertEqual("display_name 2", res["display_name"]) - def test_update_unknown_device(self): + def test_update_unknown_device(self) -> None: exc = self.get_failure( self.store.update_device( "user_id", "unknown_device_id", new_display_name="display_name 2" |