diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 1358d0acab..430c692336 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.auth import get_access_token_from_request
+from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -30,15 +29,33 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs)
- self.store = hs.get_datastore()
+ self._auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
- access_token = get_access_token_from_request(request)
- yield self.store.delete_access_token(access_token)
+ try:
+ requester = yield self.auth.get_user_by_req(request)
+ except AuthError:
+ # this implies the access token has already been deleted.
+ defer.returnValue((401, {
+ "errcode": "M_UNKNOWN_TOKEN",
+ "error": "Access Token unknown or expired"
+ }))
+ else:
+ if requester.device_id is None:
+ # the acccess token wasn't associated with a device.
+ # Just delete the access token
+ access_token = self._auth.get_access_token_from_request(request)
+ yield self._auth_handler.delete_access_token(access_token)
+ else:
+ yield self._device_handler.delete_device(
+ requester.user.to_string(), requester.device_id)
+
defer.returnValue((200, {}))
@@ -47,8 +64,9 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs)
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@@ -57,7 +75,13 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- yield self.store.user_delete_access_tokens(user_id)
+
+ # first delete all of the user's devices
+ yield self._device_handler.delete_all_devices_for_user(user_id)
+
+ # .. and then delete any access tokens which weren't associated with
+ # devices.
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {}))
|