diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 331aa13fed..85a970a6c9 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -12,11 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse import types
+
from twisted.internet import defer
+import synapse.api.errors
import synapse.handlers.device
+
import synapse.storage
+from synapse import types
from tests import unittest, utils
user1 = "@boris:aaa"
@@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(DeviceTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
- self.handler = None # type: device.DeviceHandler
+ self.handler = None # type: synapse.handlers.device.DeviceHandler
self.clock = None # type: utils.MockClock
@defer.inlineCallbacks
@@ -124,6 +127,37 @@ class DeviceTestCase(unittest.TestCase):
}, res)
@defer.inlineCallbacks
+ def test_delete_device(self):
+ yield self._record_users()
+
+ # delete the device
+ yield self.handler.delete_device(user1, "abc")
+
+ # check the device was deleted
+ with self.assertRaises(synapse.api.errors.NotFoundError):
+ yield self.handler.get_device(user1, "abc")
+
+ # we'd like to check the access token was invalidated, but that's a
+ # bit of a PITA.
+
+ @defer.inlineCallbacks
+ def test_update_device(self):
+ yield self._record_users()
+
+ update = {"display_name": "new display"}
+ yield self.handler.update_device(user1, "abc", update)
+
+ res = yield self.handler.get_device(user1, "abc")
+ self.assertEqual(res["display_name"], "new display")
+
+ @defer.inlineCallbacks
+ def test_update_unknown_device(self):
+ update = {"display_name": "new_display"}
+ with self.assertRaises(synapse.api.errors.NotFoundError):
+ yield self.handler.update_device("user_id", "unknown_device_id",
+ update)
+
+ @defer.inlineCallbacks
def _record_users(self):
# check this works for both devices which have a recorded client_ip,
# and those which don't.
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 3bd7065e32..8ac56a1fb2 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.appservice_register = Mock(
return_value=user_id
)
- self.auth_handler.issue_access_token = Mock(return_value=token)
+ self.auth_handler.get_login_tuple_for_user_id = Mock(
+ return_value=(token, "kermits_refresh_token")
+ )
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
+ "refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname
}
self.assertDictContainsSubset(det_data, result)
@@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey"
}, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
- self.auth_handler.issue_access_token = Mock(return_value=token)
+ self.auth_handler.get_login_tuple_for_user_id = Mock(
+ return_value=(token, "kermits_refresh_token")
+ )
self.device_handler.check_device_registered = \
Mock(return_value=device_id)
@@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
det_data = {
"user_id": user_id,
"access_token": token,
+ "refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname,
"device_id": device_id,
}
self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
- self.auth_handler.issue_access_token.assert_called_once_with(
- user_id, device_id=device_id)
+ self.auth_handler.get_login_tuple_for_user_id(
+ user_id, device_id=device_id, initial_device_display_name=None)
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index a6ce993375..f8725acea0 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+import synapse.api.errors
import tests.unittest
import tests.utils
@@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"device_id": "device2",
"display_name": "display_name 2",
}, res["device2"])
+
+ @defer.inlineCallbacks
+ def test_update_device(self):
+ yield self.store.store_device(
+ "user_id", "device_id", "display_name 1"
+ )
+
+ res = yield self.store.get_device("user_id", "device_id")
+ self.assertEqual("display_name 1", res["display_name"])
+
+ # do a no-op first
+ yield self.store.update_device(
+ "user_id", "device_id",
+ )
+ res = yield self.store.get_device("user_id", "device_id")
+ self.assertEqual("display_name 1", res["display_name"])
+
+ # do the update
+ yield self.store.update_device(
+ "user_id", "device_id",
+ new_display_name="display_name 2",
+ )
+
+ # check it worked
+ res = yield self.store.get_device("user_id", "device_id")
+ self.assertEqual("display_name 2", res["display_name"])
+
+ @defer.inlineCallbacks
+ def test_update_unknown_device(self):
+ with self.assertRaises(synapse.api.errors.StoreError) as cm:
+ yield self.store.update_device(
+ "user_id", "unknown_device_id",
+ new_display_name="display_name 2",
+ )
+ self.assertEqual(404, cm.exception.code)
|