diff options
-rw-r--r-- | synapse/config/saml2_config.py | 20 | ||||
-rw-r--r-- | synapse/handlers/saml2_handler.py | 39 |
2 files changed, 56 insertions, 3 deletions
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 463b5fdd68..965a97837f 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -27,7 +27,7 @@ class SAML2Config(Config): return try: - check_requirements('saml2') + check_requirements("saml2") except DependencyException as e: raise ConfigError(e.message) @@ -43,6 +43,11 @@ class SAML2Config(Config): if config_path is not None: self.saml2_sp_config.load_file(config_path) + # session lifetime: in milliseconds + self.saml2_session_lifetime = self.parse_duration( + saml2_config.get("saml_session_lifetime", "5m") + ) + def _default_saml_config_dict(self): import saml2 @@ -87,6 +92,13 @@ class SAML2Config(Config): # remote: # - url: https://our_idp/metadata.xml # + # # By default, the user has to go to our login page first. If you'd like to + # # allow IdP-initiated login, set 'allow_unsolicited: True' in an 'sp' + # # section: + # # + # #sp: + # # allow_unsolicited: True + # # # # The rest of sp_config is just used to generate our metadata xml, and you # # may well not need it, depending on your setup. Alternatively you # # may need a whole lot more detail - see the pysaml2 docs! @@ -110,6 +122,12 @@ class SAML2Config(Config): # # separate pysaml2 configuration file: # # # config_path: "%(config_dir_path)s/sp_conf.py" + # + # # the lifetime of a SAML session. This defines how long a user has to + # # complete the authentication process, if allow_unsolicited is unset. + # # The default is 5 minutes. + # # + # # saml_session_lifetime: 5m """ % { "config_dir_path": config_dir_path } diff --git a/synapse/handlers/saml2_handler.py b/synapse/handlers/saml2_handler.py index 880e6a625f..b06d3f172e 100644 --- a/synapse/handlers/saml2_handler.py +++ b/synapse/handlers/saml2_handler.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import attr import saml2 from saml2.client import Saml2Client @@ -29,6 +30,12 @@ class Saml2Handler: self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._sso_auth_handler = SSOAuthHandler(hs) + # a map from saml session id to Saml2SessionData object + self._outstanding_requests_dict = {} + + self._clock = hs.get_clock() + self._saml2_session_lifetime = hs.config.saml2_session_lifetime + def handle_redirect_request(self, client_redirect_url): """Handle an incoming request to /login/sso/redirect @@ -43,6 +50,9 @@ class Saml2Handler: relay_state=client_redirect_url ) + now = self._clock.time_msec() + self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now) + for key, value in info["headers"]: if key == "Location": return value @@ -63,9 +73,15 @@ class Saml2Handler: resp_bytes = parse_string(request, "SAMLResponse", required=True) relay_state = parse_string(request, "RelayState", required=True) + # expire outstanding sessions before parse_authn_request_response checks + # the dict. + self.expire_sessions() + try: saml2_auth = self._saml_client.parse_authn_request_response( - resp_bytes, saml2.BINDING_HTTP_POST + resp_bytes, + saml2.BINDING_HTTP_POST, + outstanding=self._outstanding_requests_dict, ) except Exception as e: logger.warning("Exception parsing SAML2 response", exc_info=1) @@ -77,10 +93,29 @@ class Saml2Handler: if "uid" not in saml2_auth.ava: raise CodeMessageException(400, "uid not in SAML2 response") - username = saml2_auth.ava["uid"][0] + self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) + username = saml2_auth.ava["uid"][0] displayName = saml2_auth.ava.get("displayName", [None])[0] return self._sso_auth_handler.on_successful_auth( username, request, relay_state, user_display_name=displayName ) + + def expire_sessions(self): + expire_before = self._clock.time_msec() - self._saml2_session_lifetime + to_expire = set() + for reqid, data in self._outstanding_requests_dict.items(): + if data.creation_time < expire_before: + to_expire.add(reqid) + for reqid in to_expire: + logger.debug("Expiring session id %s", reqid) + del self._outstanding_requests_dict[reqid] + + +@attr.s +class Saml2SessionData: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib() |