diff options
author | Erik Johnston <erik@matrix.org> | 2020-08-11 22:03:14 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2020-08-11 22:03:14 +0100 |
commit | fdb46b5442c63212a52e2296491a23f1935f9929 (patch) | |
tree | c46d0940415f96e3c0383c35f1d39b7462d5c20c /synapse/handlers/saml_handler.py | |
parent | Add comment explaining cast (diff) | |
parent | Auto set logging filter (#8051) (diff) | |
download | synapse-fdb46b5442c63212a52e2296491a23f1935f9929.tar.xz |
Merge remote-tracking branch 'origin/develop' into erikj/type_server
Diffstat (limited to 'synapse/handlers/saml_handler.py')
-rw-r--r-- | synapse/handlers/saml_handler.py | 42 |
1 files changed, 36 insertions, 6 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 2d506dc1f2..c1fcb98454 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -14,15 +14,16 @@ # limitations under the License. import logging import re -from typing import Callable, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple import attr import saml2 import saml2.response from saml2.client import Saml2Client -from synapse.api.errors import SynapseError +from synapse.api.errors import AuthError, SynapseError from synapse.config import ConfigError +from synapse.config.saml2_config import SamlAttributeRequirement from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi @@ -34,6 +35,9 @@ from synapse.types import ( from synapse.util.async_helpers import Linearizer from synapse.util.iterutils import chunk_seq +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) @@ -49,7 +53,7 @@ class Saml2SessionData: class SamlHandler: - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -62,6 +66,7 @@ class SamlHandler: self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute ) + self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements # plugin to do custom mapping from saml response to mxid self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( @@ -73,7 +78,7 @@ class SamlHandler: self._auth_provider_id = "saml" # a map from saml session id to Saml2SessionData object - self._outstanding_requests_dict = {} + self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] # a lock on the mappings self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) @@ -165,11 +170,18 @@ class SamlHandler: saml2.BINDING_HTTP_POST, outstanding=self._outstanding_requests_dict, ) + except saml2.response.UnsolicitedResponse as e: + # the pysaml2 library helpfully logs an ERROR here, but neglects to log + # the session ID. I don't really want to put the full text of the exception + # in the (user-visible) exception message, so let's log the exception here + # so we can track down the session IDs later. + logger.warning(str(e)) + raise SynapseError(400, "Unexpected SAML2 login.") except Exception as e: - raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,)) + raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,)) if saml2_auth.not_signed: - raise SynapseError(400, "SAML2 response was not signed") + raise SynapseError(400, "SAML2 response was not signed.") logger.debug("SAML2 response: %s", saml2_auth.origxml) for assertion in saml2_auth.assertions: @@ -188,6 +200,9 @@ class SamlHandler: saml2_auth.in_response_to, None ) + for requirement in self._saml2_attribute_requirements: + _check_attribute_requirement(saml2_auth.ava, requirement) + remote_user_id = self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url ) @@ -294,6 +309,21 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] +def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement): + values = ava.get(req.attribute, []) + for v in values: + if v == req.value: + return + + logger.info( + "SAML2 attribute %s did not match required value '%s' (was '%s')", + req.attribute, + req.value, + values, + ) + raise AuthError(403, "You are not authorized to log in here.") + + DOT_REPLACE_PATTERN = re.compile( ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ) |