summary refs log tree commit diff
path: root/synapse/handlers/oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/oidc.py')
-rw-r--r--synapse/handlers/oidc.py56
1 files changed, 29 insertions, 27 deletions
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index ee6e41c0e4..eca8f16040 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -72,26 +72,26 @@ _SESSION_COOKIES = [
     (b"oidc_session_no_samesite", b"HttpOnly"),
 ]
 
+
 #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
 #: OpenID.Core sec 3.1.3.3.
-Token = TypedDict(
-    "Token",
-    {
-        "access_token": str,
-        "token_type": str,
-        "id_token": Optional[str],
-        "refresh_token": Optional[str],
-        "expires_in": int,
-        "scope": Optional[str],
-    },
-)
+class Token(TypedDict):
+    access_token: str
+    token_type: str
+    id_token: Optional[str]
+    refresh_token: Optional[str]
+    expires_in: int
+    scope: Optional[str]
+
 
 #: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
 #: there is no real point of doing this in our case.
 JWK = Dict[str, str]
 
+
 #: A JWK Set, as per RFC7517 sec 5.
-JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+class JWKS(TypedDict):
+    keys: List[JWK]
 
 
 class OidcHandler:
@@ -105,9 +105,9 @@ class OidcHandler:
         assert provider_confs
 
         self._token_generator = OidcSessionTokenGenerator(hs)
-        self._providers = {
+        self._providers: Dict[str, "OidcProvider"] = {
             p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
-        }  # type: Dict[str, OidcProvider]
+        }
 
     async def load_metadata(self) -> None:
         """Validate the config and load the metadata from the remote endpoint.
@@ -178,7 +178,7 @@ class OidcHandler:
         # are two.
 
         for cookie_name, _ in _SESSION_COOKIES:
-            session = request.getCookie(cookie_name)  # type: Optional[bytes]
+            session: Optional[bytes] = request.getCookie(cookie_name)
             if session is not None:
                 break
         else:
@@ -255,7 +255,7 @@ class OidcError(Exception):
 
     def __str__(self):
         if self.error_description:
-            return "{}: {}".format(self.error, self.error_description)
+            return f"{self.error}: {self.error_description}"
         return self.error
 
 
@@ -277,7 +277,7 @@ class OidcProvider:
         self._token_generator = token_generator
 
         self._config = provider
-        self._callback_url = hs.config.oidc_callback_url  # type: str
+        self._callback_url: str = hs.config.oidc_callback_url
 
         # Calculate the prefix for OIDC callback paths based on the public_baseurl.
         # We'll insert this into the Path= parameter of any session cookies we set.
@@ -290,7 +290,7 @@ class OidcProvider:
         self._scopes = provider.scopes
         self._user_profile_method = provider.user_profile_method
 
-        client_secret = None  # type: Union[None, str, JwtClientSecret]
+        client_secret: Optional[Union[str, JwtClientSecret]] = None
         if provider.client_secret:
             client_secret = provider.client_secret
         elif provider.client_secret_jwt_key:
@@ -305,7 +305,7 @@ class OidcProvider:
             provider.client_id,
             client_secret,
             provider.client_auth_method,
-        )  # type: ClientAuth
+        )
         self._client_auth_method = provider.client_auth_method
 
         # cache of metadata for the identity provider (endpoint uris, mostly). This is
@@ -324,7 +324,7 @@ class OidcProvider:
         self._allow_existing_users = provider.allow_existing_users
 
         self._http_client = hs.get_proxied_http_client()
-        self._server_name = hs.config.server_name  # type: str
+        self._server_name: str = hs.config.server_name
 
         # identifier for the external_ids table
         self.idp_id = provider.idp_id
@@ -639,7 +639,7 @@ class OidcProvider:
             )
             logger.warning(description)
             # Body was still valid JSON. Might be useful to log it for debugging.
-            logger.warning("Code exchange response: {resp!r}".format(resp=resp))
+            logger.warning("Code exchange response: %r", resp)
             raise OidcError("server_error", description)
 
         return resp
@@ -1217,10 +1217,12 @@ class OidcSessionData:
     ui_auth_session_id = attr.ib(type=str)
 
 
-UserAttributeDict = TypedDict(
-    "UserAttributeDict",
-    {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
-)
+class UserAttributeDict(TypedDict):
+    localpart: Optional[str]
+    display_name: Optional[str]
+    emails: List[str]
+
+
 C = TypeVar("C")
 
 
@@ -1381,7 +1383,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         if display_name == "":
             display_name = None
 
-        emails = []  # type: List[str]
+        emails: List[str] = []
         email = render_template_field(self._config.email_template)
         if email:
             emails.append(email)
@@ -1391,7 +1393,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         )
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
-        extras = {}  # type: Dict[str, str]
+        extras: Dict[str, str] = {}
         for key, template in self._config.extra_attributes.items():
             try:
                 extras[key] = template.render(user=userinfo).strip()