diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 331aa13fed..214e722eb3 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -12,11 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse import types
+
from twisted.internet import defer
+import synapse.api.errors
import synapse.handlers.device
+
import synapse.storage
+from synapse import types
from tests import unittest, utils
user1 = "@boris:aaa"
@@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(DeviceTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
- self.handler = None # type: device.DeviceHandler
+ self.handler = None # type: synapse.handlers.device.DeviceHandler
self.clock = None # type: utils.MockClock
@defer.inlineCallbacks
@@ -124,6 +127,21 @@ class DeviceTestCase(unittest.TestCase):
}, res)
@defer.inlineCallbacks
+ def test_delete_device(self):
+ yield self._record_users()
+
+ # delete the device
+ yield self.handler.delete_device(user1, "abc")
+
+ # check the device was deleted
+ with self.assertRaises(synapse.api.errors.NotFoundError):
+ yield self.handler.get_device(user1, "abc")
+
+ # we'd like to check the access token was invalidated, but that's a
+ # bit of a PITA.
+
+
+ @defer.inlineCallbacks
def _record_users(self):
# check this works for both devices which have a recorded client_ip,
# and those which don't.
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 3bd7065e32..8ac56a1fb2 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.appservice_register = Mock(
return_value=user_id
)
- self.auth_handler.issue_access_token = Mock(return_value=token)
+ self.auth_handler.get_login_tuple_for_user_id = Mock(
+ return_value=(token, "kermits_refresh_token")
+ )
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
+ "refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname
}
self.assertDictContainsSubset(det_data, result)
@@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey"
}, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
- self.auth_handler.issue_access_token = Mock(return_value=token)
+ self.auth_handler.get_login_tuple_for_user_id = Mock(
+ return_value=(token, "kermits_refresh_token")
+ )
self.device_handler.check_device_registered = \
Mock(return_value=device_id)
@@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
det_data = {
"user_id": user_id,
"access_token": token,
+ "refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname,
"device_id": device_id,
}
self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
- self.auth_handler.issue_access_token.assert_called_once_with(
- user_id, device_id=device_id)
+ self.auth_handler.get_login_tuple_for_user_id(
+ user_id, device_id=device_id, initial_device_display_name=None)
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False
|