diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 71a89f09c7..1924636c4d 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -57,13 +57,10 @@ class EventStreamHandler(BaseHandler):
timeout=0,
as_client_event=True,
affect_presence=True,
- only_keys=None,
room_id=None,
is_guest=False,
):
"""Fetches the events stream for a given user.
-
- If `only_keys` is not None, events from keys will be sent down.
"""
if room_id:
@@ -93,7 +90,6 @@ class EventStreamHandler(BaseHandler):
auth_user,
pagin_config,
timeout,
- only_keys=only_keys,
is_guest=is_guest,
explicit_room_id=room_id,
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 2d506dc1f2..c1fcb98454 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,15 +14,16 @@
# limitations under the License.
import logging
import re
-from typing import Callable, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
import attr
import saml2
import saml2.response
from saml2.client import Saml2Client
-from synapse.api.errors import SynapseError
+from synapse.api.errors import AuthError, SynapseError
from synapse.config import ConfigError
+from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@@ -34,6 +35,9 @@ from synapse.types import (
from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
@@ -49,7 +53,7 @@ class Saml2SessionData:
class SamlHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
@@ -62,6 +66,7 @@ class SamlHandler:
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
+ self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@@ -73,7 +78,7 @@ class SamlHandler:
self._auth_provider_id = "saml"
# a map from saml session id to Saml2SessionData object
- self._outstanding_requests_dict = {}
+ self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
@@ -165,11 +170,18 @@ class SamlHandler:
saml2.BINDING_HTTP_POST,
outstanding=self._outstanding_requests_dict,
)
+ except saml2.response.UnsolicitedResponse as e:
+ # the pysaml2 library helpfully logs an ERROR here, but neglects to log
+ # the session ID. I don't really want to put the full text of the exception
+ # in the (user-visible) exception message, so let's log the exception here
+ # so we can track down the session IDs later.
+ logger.warning(str(e))
+ raise SynapseError(400, "Unexpected SAML2 login.")
except Exception as e:
- raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,))
+ raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
if saml2_auth.not_signed:
- raise SynapseError(400, "SAML2 response was not signed")
+ raise SynapseError(400, "SAML2 response was not signed.")
logger.debug("SAML2 response: %s", saml2_auth.origxml)
for assertion in saml2_auth.assertions:
@@ -188,6 +200,9 @@ class SamlHandler:
saml2_auth.in_response_to, None
)
+ for requirement in self._saml2_attribute_requirements:
+ _check_attribute_requirement(saml2_auth.ava, requirement)
+
remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
)
@@ -294,6 +309,21 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
+def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
+ values = ava.get(req.attribute, [])
+ for v in values:
+ if v == req.value:
+ return
+
+ logger.info(
+ "SAML2 attribute %s did not match required value '%s' (was '%s')",
+ req.attribute,
+ req.value,
+ values,
+ )
+ raise AuthError(403, "You are not authorized to log in here.")
+
+
DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)
|