diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 1197158fdc..5022808ea9 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -137,8 +137,8 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
- self._auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__(hs)
+ self._deactivate_account_handler = hs.get_deactivate_account_handler()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
@@ -149,7 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- yield self._auth_handler.deactivate_account(target_user_id)
+ yield self._deactivate_account_handler.deactivate_account(target_user_id)
defer.returnValue((200, {}))
@@ -309,7 +309,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
super(ResetPasswordRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- self.auth_handler = hs.get_auth_handler()
+ self._set_password_handler = hs.get_set_password_handler()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
@@ -330,7 +330,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
logger.info("new_password: %r", new_password)
- yield self.auth_handler.set_password(
+ yield self._set_password_handler.set_password(
target_user_id, new_password, requester
)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 6add754782..ca49955935 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -16,6 +16,7 @@
from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request
+from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns
@@ -30,15 +31,30 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs)
+ self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
- access_token = get_access_token_from_request(request)
- yield self._auth_handler.delete_access_token(access_token)
+ try:
+ requester = yield self.auth.get_user_by_req(request)
+ except AuthError:
+ # this implies the access token has already been deleted.
+ pass
+ else:
+ if requester.device_id is None:
+ # the acccess token wasn't associated with a device.
+ # Just delete the access token
+ access_token = get_access_token_from_request(request)
+ yield self._auth_handler.delete_access_token(access_token)
+ else:
+ yield self._device_handler.delete_device(
+ requester.user.to_string(), requester.device_id)
+
defer.returnValue((200, {}))
@@ -49,6 +65,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
super(LogoutAllRestServlet, self).__init__(hs)
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@@ -57,6 +74,12 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
+
+ # first delete all of the user's devices
+ yield self._device_handler.delete_all_devices_for_user(user_id)
+
+ # .. and then delete any access tokens which weren't associated with
+ # devices.
yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {}))
|