summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_password_providers.py28
1 files changed, 27 insertions, 1 deletions
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 08e9730d4d..2add72b28a 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
 import synapse
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login
+from synapse.rest.client import devices, login, logout
 from synapse.types import JsonDict
 
 from tests import unittest
@@ -155,6 +155,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         synapse.rest.admin.register_servlets,
         login.register_servlets,
         devices.register_servlets,
+        logout.register_servlets,
     ]
 
     def setUp(self):
@@ -719,6 +720,31 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         channel = self._send_password_login("localuser", "localpass")
         self.assertEqual(channel.code, 400, channel.result)
 
+    def test_on_logged_out(self):
+        """Tests that the on_logged_out callback is called when the user logs out."""
+        self.register_user("rin", "password")
+        tok = self.login("rin", "password")
+
+        self.called = False
+
+        async def on_logged_out(user_id, device_id, access_token):
+            self.called = True
+
+        on_logged_out = Mock(side_effect=on_logged_out)
+        self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
+            on_logged_out
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/logout",
+            {},
+            access_token=tok,
+        )
+        self.assertEqual(channel.code, 200)
+        on_logged_out.assert_called_once()
+        self.assertTrue(self.called)
+
     def _get_login_flows(self) -> JsonDict:
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)