summary refs log tree commit diff
path: root/synapse/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client')
-rw-r--r--synapse/rest/client/v1/logout.py27
1 files changed, 25 insertions, 2 deletions
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 6add754782..ca49955935 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -16,6 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.auth import get_access_token_from_request
+from synapse.api.errors import AuthError
 
 from .base import ClientV1RestServlet, client_path_patterns
 
@@ -30,15 +31,30 @@ class LogoutRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(LogoutRestServlet, self).__init__(hs)
+        self._auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
 
     def on_OPTIONS(self, request):
         return (200, {})
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        access_token = get_access_token_from_request(request)
-        yield self._auth_handler.delete_access_token(access_token)
+        try:
+            requester = yield self.auth.get_user_by_req(request)
+        except AuthError:
+            # this implies the access token has already been deleted.
+            pass
+        else:
+            if requester.device_id is None:
+                # the acccess token wasn't associated with a device.
+                # Just delete the access token
+                access_token = get_access_token_from_request(request)
+                yield self._auth_handler.delete_access_token(access_token)
+            else:
+                yield self._device_handler.delete_device(
+                    requester.user.to_string(), requester.device_id)
+
         defer.returnValue((200, {}))
 
 
@@ -49,6 +65,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
         super(LogoutAllRestServlet, self).__init__(hs)
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
 
     def on_OPTIONS(self, request):
         return (200, {})
@@ -57,6 +74,12 @@ class LogoutAllRestServlet(ClientV1RestServlet):
     def on_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
+
+        # first delete all of the user's devices
+        yield self._device_handler.delete_all_devices_for_user(user_id)
+
+        # .. and then delete any access tokens which weren't associated with
+        # devices.
         yield self._auth_handler.delete_access_tokens_for_user(user_id)
         defer.returnValue((200, {}))