summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-05-07 09:30:45 -0400
committerGitHub <noreply@github.com>2020-05-07 09:30:45 -0400
commit22246919e38ef47780b080a8109728ffeba77f69 (patch)
treefba244f0258333b4f444ae18c0ec6882aacaeed4 /synapse
parentSupport any process writing to cache invalidation stream. (#7436) (diff)
downloadsynapse-22246919e38ef47780b080a8109728ffeba77f69.tar.xz
Add more type hints to SAML handler. (#7445)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/saml_handler.py28
1 files changed, 16 insertions, 12 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 96f2dd36ad..e7015c704f 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import Optional, Tuple
+from typing import Callable, Dict, Optional, Set, Tuple
 
 import attr
 import saml2
@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.http.server import finish_request
 from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
 from synapse.module_api.errors import RedirectException
 from synapse.types import (
@@ -81,17 +82,19 @@ class SamlHandler:
 
         self._error_html_content = hs.config.saml2_error_html_content
 
-    def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None):
+    def handle_redirect_request(
+        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
+    ) -> bytes:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
-            client_redirect_url (bytes): the URL that we should redirect the
+            client_redirect_url: the URL that we should redirect the
                 client to when everything is done
-            ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
         Returns:
-            bytes: URL to redirect to
+            URL to redirect to
         """
         reqid, info = self._saml_client.prepare_for_authenticate(
             relay_state=client_redirect_url
@@ -109,15 +112,15 @@ class SamlHandler:
         # this shouldn't happen!
         raise Exception("prepare_for_authenticate didn't return a Location header")
 
-    async def handle_saml_response(self, request):
+    async def handle_saml_response(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_matrix/saml2/authn_response
 
         Args:
-            request (SynapseRequest): the incoming request from the browser. We'll
+            request: the incoming request from the browser. We'll
                 respond to it with a redirect.
 
         Returns:
-            Deferred[none]: Completes once we have handled the request.
+            Completes once we have handled the request.
         """
         resp_bytes = parse_string(request, "SAMLResponse", required=True)
         relay_state = parse_string(request, "RelayState", required=True)
@@ -310,6 +313,7 @@ DOT_REPLACE_PATTERN = re.compile(
 
 
 def dot_replace_for_mxid(username: str) -> str:
+    """Replace any characters which are not allowed in Matrix IDs with a dot."""
     username = username.lower()
     username = DOT_REPLACE_PATTERN.sub(".", username)
 
@@ -321,7 +325,7 @@ def dot_replace_for_mxid(username: str) -> str:
 MXID_MAPPER_MAP = {
     "hexencode": map_username_to_mxid_localpart,
     "dotreplace": dot_replace_for_mxid,
-}
+}  # type: Dict[str, Callable[[str], str]]
 
 
 @attr.s
@@ -349,7 +353,7 @@ class DefaultSamlMappingProvider(object):
 
     def get_remote_user_id(
         self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
-    ):
+    ) -> str:
         """Extracts the remote user id from the SAML response"""
         try:
             return saml_response.ava["uid"][0]
@@ -428,14 +432,14 @@ class DefaultSamlMappingProvider(object):
         return SamlConfig(mxid_source_attribute, mxid_mapper)
 
     @staticmethod
-    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+    def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
         """Returns the required attributes of a SAML
 
         Args:
             config: A SamlConfig object containing configuration params for this provider
 
         Returns:
-            tuple[set,set]: The first set equates to the saml auth response
+            The first set equates to the saml auth response
                 attributes that are required for the module to function, whereas the
                 second set consists of those attributes which can be used if
                 available, but are not necessary