summary refs log tree commit diff
path: root/tests/test_utils
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_utils')
-rw-r--r--tests/test_utils/oidc.py27
1 files changed, 25 insertions, 2 deletions
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index de134bbc89..1461d23ee8 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -51,6 +51,8 @@ class FakeOidcServer:
     get_userinfo_handler: Mock
     post_token_handler: Mock
 
+    sid_counter: int = 0
+
     def __init__(self, clock: Clock, issuer: str):
         from authlib.jose import ECKey, KeySet
 
@@ -146,7 +148,7 @@ class FakeOidcServer:
         return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
 
     def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
-        now = self._clock.time()
+        now = int(self._clock.time())
         id_token = {
             **grant.userinfo,
             "iss": self.issuer,
@@ -166,6 +168,26 @@ class FakeOidcServer:
 
         return self._sign(id_token)
 
+    def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str:
+        now = int(self._clock.time())
+        logout_token = {
+            "iss": self.issuer,
+            "aud": grant.client_id,
+            "iat": now,
+            "jti": random_string(10),
+            "events": {
+                "http://schemas.openid.net/event/backchannel-logout": {},
+            },
+        }
+
+        if grant.sid is not None:
+            logout_token["sid"] = grant.sid
+
+        if "sub" in grant.userinfo:
+            logout_token["sub"] = grant.userinfo["sub"]
+
+        return self._sign(logout_token)
+
     def id_token_override(self, overrides: dict):
         """Temporarily patch the ID token generated by the token endpoint."""
         return patch.object(self, "_id_token_overrides", overrides)
@@ -183,7 +205,8 @@ class FakeOidcServer:
         code = random_string(10)
         sid = None
         if with_sid:
-            sid = random_string(10)
+            sid = str(self.sid_counter)
+            self.sid_counter += 1
 
         grant = FakeAuthorizationGrant(
             userinfo=userinfo,