summary refs log tree commit diff
path: root/tests/handlers/test_saml.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_saml.py')
-rw-r--r--tests/handlers/test_saml.py33
1 files changed, 22 insertions, 11 deletions
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index a0f84e2940..9b1b8b9f13 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,7 +12,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Set, Tuple
 from unittest.mock import Mock
 
 import attr
@@ -20,7 +20,9 @@ import attr
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import RedirectException
+from synapse.module_api import ModuleApi
 from synapse.server import HomeServer
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests.test_utils import simple_async_mock
@@ -29,6 +31,7 @@ from tests.unittest import HomeserverTestCase, override_config
 # Check if we have the dependencies to run the tests.
 try:
     import saml2.config
+    import saml2.response
     from saml2.sigver import SigverError
 
     has_saml2 = True
@@ -56,31 +59,39 @@ class FakeAuthnResponse:
 
 
 class TestMappingProvider:
-    def __init__(self, config, module):
+    def __init__(self, config: None, module: ModuleApi):
         pass
 
     @staticmethod
-    def parse_config(config):
-        return
+    def parse_config(config: JsonDict) -> None:
+        return None
 
     @staticmethod
-    def get_saml_attributes(config):
+    def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]:
         return {"uid"}, {"displayName"}
 
-    def get_remote_user_id(self, saml_response, client_redirect_url):
+    def get_remote_user_id(
+        self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str
+    ) -> str:
         return saml_response.ava["uid"]
 
     def saml_response_to_user_attributes(
-        self, saml_response, failures, client_redirect_url
-    ):
+        self,
+        saml_response: "saml2.response.AuthnResponse",
+        failures: int,
+        client_redirect_url: str,
+    ) -> dict:
         localpart = saml_response.ava["username"] + (str(failures) if failures else "")
         return {"mxid_localpart": localpart, "displayname": None}
 
 
 class TestRedirectMappingProvider(TestMappingProvider):
     def saml_response_to_user_attributes(
-        self, saml_response, failures, client_redirect_url
-    ):
+        self,
+        saml_response: "saml2.response.AuthnResponse",
+        failures: int,
+        client_redirect_url: str,
+    ) -> dict:
         raise RedirectException(b"https://custom-saml-redirect/")
 
 
@@ -347,7 +358,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
 
 
-def _mock_request():
+def _mock_request() -> Mock:
     """Returns a mock which will stand in as a SynapseRequest"""
     mock = Mock(
         spec=[