summary refs log tree commit diff
path: root/synapse/rest/client/v1/logout.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/logout.py')
-rw-r--r--synapse/rest/client/v1/logout.py42
1 files changed, 33 insertions, 9 deletions
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 1358d0acab..430c692336 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -13,15 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
 from twisted.internet import defer
 
-from synapse.api.auth import get_access_token_from_request
+from synapse.api.errors import AuthError
 
 from .base import ClientV1RestServlet, client_path_patterns
 
-import logging
-
-
 logger = logging.getLogger(__name__)
 
 
@@ -30,15 +29,33 @@ class LogoutRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(LogoutRestServlet, self).__init__(hs)
-        self.store = hs.get_datastore()
+        self._auth = hs.get_auth()
+        self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
 
     def on_OPTIONS(self, request):
         return (200, {})
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        access_token = get_access_token_from_request(request)
-        yield self.store.delete_access_token(access_token)
+        try:
+            requester = yield self.auth.get_user_by_req(request)
+        except AuthError:
+            # this implies the access token has already been deleted.
+            defer.returnValue((401, {
+                "errcode": "M_UNKNOWN_TOKEN",
+                "error": "Access Token unknown or expired"
+            }))
+        else:
+            if requester.device_id is None:
+                # the acccess token wasn't associated with a device.
+                # Just delete the access token
+                access_token = self._auth.get_access_token_from_request(request)
+                yield self._auth_handler.delete_access_token(access_token)
+            else:
+                yield self._device_handler.delete_device(
+                    requester.user.to_string(), requester.device_id)
+
         defer.returnValue((200, {}))
 
 
@@ -47,8 +64,9 @@ class LogoutAllRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(LogoutAllRestServlet, self).__init__(hs)
-        self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
 
     def on_OPTIONS(self, request):
         return (200, {})
@@ -57,7 +75,13 @@ class LogoutAllRestServlet(ClientV1RestServlet):
     def on_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
-        yield self.store.user_delete_access_tokens(user_id)
+
+        # first delete all of the user's devices
+        yield self._device_handler.delete_all_devices_for_user(user_id)
+
+        # .. and then delete any access tokens which weren't associated with
+        # devices.
+        yield self._auth_handler.delete_access_tokens_for_user(user_id)
         defer.returnValue((200, {}))