diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/__init__.py | 2 | ||||
-rw-r--r-- | synapse/config/saml2_config.py | 33 | ||||
-rw-r--r-- | synapse/handlers/account_validity.py | 10 | ||||
-rw-r--r-- | synapse/handlers/saml_handler.py | 123 | ||||
-rw-r--r-- | synapse/http/server.py | 26 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 63 | ||||
-rw-r--r-- | synapse/rest/media/v1/preview_url_resource.py | 1 | ||||
-rw-r--r-- | synapse/rest/saml2/response_resource.py | 37 | ||||
-rw-r--r-- | synapse/server.py | 6 | ||||
-rw-r--r-- | synapse/storage/events.py | 9 | ||||
-rw-r--r-- | synapse/storage/registration.py | 13 | ||||
-rw-r--r-- | synapse/util/__init__.py | 8 | ||||
-rw-r--r-- | synapse/util/logcontext.py | 9 |
13 files changed, 271 insertions, 69 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 119359be68..5fe8631973 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -35,4 +35,4 @@ try: except ImportError: pass -__version__ = "1.0.0" +__version__ = "1.1.0rc2" diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 872a1ba934..6a8161547a 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from synapse.python_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError @@ -25,6 +26,11 @@ class SAML2Config(Config): if not saml2_config or not saml2_config.get("enabled", True): return + try: + check_requirements("saml2") + except DependencyException as e: + raise ConfigError(e.message) + self.saml2_enabled = True import saml2.config @@ -37,6 +43,11 @@ class SAML2Config(Config): if config_path is not None: self.saml2_sp_config.load_file(config_path) + # session lifetime: in milliseconds + self.saml2_session_lifetime = self.parse_duration( + saml2_config.get("saml_session_lifetime", "5m") + ) + def _default_saml_config_dict(self): import saml2 @@ -72,6 +83,12 @@ class SAML2Config(Config): # so it is not normally necessary to specify them unless you need to # override them. # + # Once SAML support is enabled, a metadata file will be exposed at + # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to + # use to configure your SAML IdP with. Alternatively, you can manually configure + # the IdP to use an ACS location of + # https://<server>:<port>/_matrix/saml2/authn_response. + # #saml2_config: # sp_config: # # point this to the IdP's metadata. You can use either a local file or @@ -81,7 +98,15 @@ class SAML2Config(Config): # remote: # - url: https://our_idp/metadata.xml # - # # The rest of sp_config is just used to generate our metadata xml, and you + # # By default, the user has to go to our login page first. If you'd like to + # # allow IdP-initiated login, set 'allow_unsolicited: True' in a + # # 'service.sp' section: + # # + # #service: + # # sp: + # # allow_unsolicited: True + # + # # The examples below are just used to generate our metadata xml, and you # # may well not need it, depending on your setup. Alternatively you # # may need a whole lot more detail - see the pysaml2 docs! # @@ -104,6 +129,12 @@ class SAML2Config(Config): # # separate pysaml2 configuration file: # # # config_path: "%(config_dir_path)s/sp_conf.py" + # + # # the lifetime of a SAML session. This defines how long a user has to + # # complete the authentication process, if allow_unsolicited is unset. + # # The default is 5 minutes. + # # + # # saml_session_lifetime: 5m """ % { "config_dir_path": config_dir_path } diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 0719da3ab7..edb48054a0 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -22,6 +22,7 @@ from email.mime.text import MIMEText from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID from synapse.util import stringutils from synapse.util.logcontext import make_deferred_yieldable @@ -67,7 +68,14 @@ class AccountValidityHandler(object): ) # Check the renewal emails to send and send them every 30min. - self.clock.looping_call(self.send_renewal_emails, 30 * 60 * 1000) + def send_emails(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "send_renewals", self.send_renewal_emails + ) + + self.clock.looping_call(send_emails, 30 * 60 * 1000) @defer.inlineCallbacks def send_renewal_emails(self): diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py new file mode 100644 index 0000000000..a1ce6929cf --- /dev/null +++ b/synapse/handlers/saml_handler.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import attr +import saml2 +from saml2.client import Saml2Client + +from synapse.api.errors import SynapseError +from synapse.http.servlet import parse_string +from synapse.rest.client.v1.login import SSOAuthHandler + +logger = logging.getLogger(__name__) + + +class SamlHandler: + def __init__(self, hs): + self._saml_client = Saml2Client(hs.config.saml2_sp_config) + self._sso_auth_handler = SSOAuthHandler(hs) + + # a map from saml session id to Saml2SessionData object + self._outstanding_requests_dict = {} + + self._clock = hs.get_clock() + self._saml2_session_lifetime = hs.config.saml2_session_lifetime + + def handle_redirect_request(self, client_redirect_url): + """Handle an incoming request to /login/sso/redirect + + Args: + client_redirect_url (bytes): the URL that we should redirect the + client to when everything is done + + Returns: + bytes: URL to redirect to + """ + reqid, info = self._saml_client.prepare_for_authenticate( + relay_state=client_redirect_url + ) + + now = self._clock.time_msec() + self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now) + + for key, value in info["headers"]: + if key == "Location": + return value + + # this shouldn't happen! + raise Exception("prepare_for_authenticate didn't return a Location header") + + def handle_saml_response(self, request): + """Handle an incoming request to /_matrix/saml2/authn_response + + Args: + request (SynapseRequest): the incoming request from the browser. We'll + respond to it with a redirect. + + Returns: + Deferred[none]: Completes once we have handled the request. + """ + resp_bytes = parse_string(request, "SAMLResponse", required=True) + relay_state = parse_string(request, "RelayState", required=True) + + # expire outstanding sessions before parse_authn_request_response checks + # the dict. + self.expire_sessions() + + try: + saml2_auth = self._saml_client.parse_authn_request_response( + resp_bytes, + saml2.BINDING_HTTP_POST, + outstanding=self._outstanding_requests_dict, + ) + except Exception as e: + logger.warning("Exception parsing SAML2 response: %s", e) + raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,)) + + if saml2_auth.not_signed: + logger.warning("SAML2 response was not signed") + raise SynapseError(400, "SAML2 response was not signed") + + if "uid" not in saml2_auth.ava: + logger.warning("SAML2 response lacks a 'uid' attestation") + raise SynapseError(400, "uid not in SAML2 response") + + self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) + + username = saml2_auth.ava["uid"][0] + displayName = saml2_auth.ava.get("displayName", [None])[0] + + return self._sso_auth_handler.on_successful_auth( + username, request, relay_state, user_display_name=displayName + ) + + def expire_sessions(self): + expire_before = self._clock.time_msec() - self._saml2_session_lifetime + to_expire = set() + for reqid, data in self._outstanding_requests_dict.items(): + if data.creation_time < expire_before: + to_expire.add(reqid) + for reqid in to_expire: + logger.debug("Expiring session id %s", reqid) + del self._outstanding_requests_dict[reqid] + + +@attr.s +class Saml2SessionData: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib() diff --git a/synapse/http/server.py b/synapse/http/server.py index f067c163c1..d993161a3e 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -65,8 +65,8 @@ def wrap_json_request_handler(h): The handler method must have a signature of "handle_foo(self, request)", where "request" must be a SynapseRequest. - The handler must return a deferred. If the deferred succeeds we assume that - a response has been sent. If the deferred fails with a SynapseError we use + The handler must return a deferred or a coroutine. If the deferred succeeds + we assume that a response has been sent. If the deferred fails with a SynapseError we use it to send a JSON response with the appropriate HTTP reponse code. If the deferred fails with any other type of error we send a 500 reponse. """ @@ -353,16 +353,22 @@ class DirectServeResource(resource.Resource): """ Render the request, using an asynchronous render handler if it exists. """ - render_callback_name = "_async_render_" + request.method.decode("ascii") + async_render_callback_name = "_async_render_" + request.method.decode("ascii") - if hasattr(self, render_callback_name): - # Call the handler - callback = getattr(self, render_callback_name) - defer.ensureDeferred(callback(request)) + # Try and get the async renderer + callback = getattr(self, async_render_callback_name, None) - return NOT_DONE_YET - else: - super().render(request) + # No async renderer for this request method. + if not callback: + return super().render(request) + + resp = callback(request) + + # If it's a coroutine, turn it into a Deferred + if isinstance(resp, types.CoroutineType): + defer.ensureDeferred(resp) + + return NOT_DONE_YET def _options_handler(request): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index ede6bc8b1e..f961178235 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -86,6 +86,7 @@ class LoginRestServlet(RestServlet): self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm + self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() @@ -97,6 +98,9 @@ class LoginRestServlet(RestServlet): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) + if self.saml2_enabled: + flows.append({"type": LoginRestServlet.SSO_TYPE}) + flows.append({"type": LoginRestServlet.TOKEN_TYPE}) if self.cas_enabled: flows.append({"type": LoginRestServlet.SSO_TYPE}) @@ -319,12 +323,12 @@ class LoginRestServlet(RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID(user, self.hs.hostname).to_string() + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: - device_id = login_submission.get("device_id") - initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( registered_user_id, device_id, initial_display_name ) @@ -338,11 +342,8 @@ class LoginRestServlet(RestServlet): user_id, access_token = ( yield self.registration_handler.register(localpart=user) ) - - device_id = login_submission.get("device_id") - initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name + user_id, device_id, initial_display_name ) result = { @@ -354,27 +355,49 @@ class LoginRestServlet(RestServlet): defer.returnValue(result) -class CasRedirectServlet(RestServlet): +class BaseSSORedirectServlet(RestServlet): + """Common base class for /login/sso/redirect impls""" + PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + def on_GET(self, request): + args = request.args + if b"redirectUrl" not in args: + return 400, "Redirect URL not specified for SSO auth" + client_redirect_url = args[b"redirectUrl"][0] + sso_url = self.get_sso_url(client_redirect_url) + request.redirect(sso_url) + finish_request(request) + + def get_sso_url(self, client_redirect_url): + """Get the URL to redirect to, to perform SSO auth + + Args: + client_redirect_url (bytes): the URL that we should redirect the + client to when everything is done + + Returns: + bytes: URL to redirect to + """ + # to be implemented by subclasses + raise NotImplementedError() + + +class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__() self.cas_server_url = hs.config.cas_server_url.encode("ascii") self.cas_service_url = hs.config.cas_service_url.encode("ascii") - def on_GET(self, request): - args = request.args - if b"redirectUrl" not in args: - return (400, "Redirect URL not specified for CAS auth") + def get_sso_url(self, client_redirect_url): client_redirect_url_param = urllib.parse.urlencode( - {b"redirectUrl": args[b"redirectUrl"][0]} + {b"redirectUrl": client_redirect_url} ).encode("ascii") hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" service_param = urllib.parse.urlencode( {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} ).encode("ascii") - request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) - finish_request(request) + return b"%s/login?%s" % (self.cas_server_url, service_param) class CasTicketServlet(RestServlet): @@ -457,6 +480,16 @@ class CasTicketServlet(RestServlet): return user, attributes +class SAMLRedirectServlet(BaseSSORedirectServlet): + PATTERNS = client_patterns("/login/sso/redirect", v1=True) + + def __init__(self, hs): + self._saml_handler = hs.get_saml_handler() + + def get_sso_url(self, client_redirect_url): + return self._saml_handler.handle_redirect_request(client_redirect_url) + + class SSOAuthHandler(object): """ Utility class for Resources and Servlets which handle the response from a SSO @@ -532,3 +565,5 @@ def register_servlets(hs, http_server): if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) + elif hs.config.saml2_enabled: + SAMLRedirectServlet(hs).register(http_server) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0337b64dc2..053346fb86 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -95,6 +95,7 @@ class PreviewUrlResource(DirectServeResource): ) def render_OPTIONS(self, request): + request.setHeader(b"Allow", b"OPTIONS, GET") return respond_with_json(request, 200, {}, send_cors=True) @wrap_json_request_handler diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index 939c87306c..69ecc5e4b4 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -13,17 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import saml2 -from saml2.client import Saml2Client - -from synapse.api.errors import CodeMessageException from synapse.http.server import DirectServeResource, wrap_html_request_handler -from synapse.http.servlet import parse_string -from synapse.rest.client.v1.login import SSOAuthHandler - -logger = logging.getLogger(__name__) class SAML2ResponseResource(DirectServeResource): @@ -33,32 +24,8 @@ class SAML2ResponseResource(DirectServeResource): def __init__(self, hs): super().__init__() - - self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._sso_auth_handler = SSOAuthHandler(hs) + self._saml_handler = hs.get_saml_handler() @wrap_html_request_handler async def _async_render_POST(self, request): - resp_bytes = parse_string(request, "SAMLResponse", required=True) - relay_state = parse_string(request, "RelayState", required=True) - - try: - saml2_auth = self._saml_client.parse_authn_request_response( - resp_bytes, saml2.BINDING_HTTP_POST - ) - except Exception as e: - logger.warning("Exception parsing SAML2 response", exc_info=1) - raise CodeMessageException(400, "Unable to parse SAML2 response: %s" % (e,)) - - if saml2_auth.not_signed: - raise CodeMessageException(400, "SAML2 response was not signed") - - if "uid" not in saml2_auth.ava: - raise CodeMessageException(400, "uid not in SAML2 response") - - username = saml2_auth.ava["uid"][0] - - displayName = saml2_auth.ava.get("displayName", [None])[0] - return self._sso_auth_handler.on_successful_auth( - username, request, relay_state, user_display_name=displayName - ) + return await self._saml_handler.handle_saml_response(request) diff --git a/synapse/server.py b/synapse/server.py index a9592c396c..9e28dba2b1 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -194,6 +194,7 @@ class HomeServer(object): "sendmail", "registration_handler", "account_validity_handler", + "saml_handler", "event_client_serializer", ] @@ -524,6 +525,11 @@ class HomeServer(object): def build_account_validity_handler(self): return AccountValidityHandler(self) + def build_saml_handler(self): + from synapse.handlers.saml_handler import SamlHandler + + return SamlHandler(self) + def build_event_client_serializer(self): return EventClientSerializer(self) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index fefba39ea1..86f8485704 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -253,7 +253,14 @@ class EventsStore( ) # Read the extrems every 60 minutes - hs.get_clock().looping_call(self._read_forward_extremities, 60 * 60 * 1000) + def read_forward_extremities(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "read_forward_extremities", self._read_forward_extremities + ) + + hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) @defer.inlineCallbacks def _read_forward_extremities(self): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 983ce13291..13a3d5208b 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, ThreepidValidationError +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore from synapse.types import UserID @@ -619,9 +620,15 @@ class RegistrationStore( ) # Create a background job for culling expired 3PID validity tokens - hs.get_clock().looping_call( - self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS - ) + def start_cull(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "cull_expired_threepid_validation_tokens", + self.cull_expired_threepid_validation_tokens, + ) + + hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) @defer.inlineCallbacks def _backgroud_update_set_deactivated_flag(self, progress, batch_size): diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index dcc747cac1..954e32fb2a 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -62,7 +62,10 @@ class Clock(object): def looping_call(self, f, msec): """Call a function repeatedly. - Waits `msec` initially before calling `f` for the first time. + Waits `msec` initially before calling `f` for the first time. + + Note that the function will be called with no logcontext, so if it is anything + other than trivial, you probably want to wrap it in run_as_background_process. Args: f(function): The function to call repeatedly. @@ -77,6 +80,9 @@ class Clock(object): def call_later(self, delay, callback, *args, **kwargs): """Call something later + Note that the function will be called with no logcontext, so if it is anything + other than trivial, you probably want to wrap it in run_as_background_process. + Args: delay(float): How long to wait in seconds. callback(function): Function to call diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 6b0d2deea0..9e1b537804 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -24,6 +24,7 @@ See doc/log_contexts.rst for details on how this works. import logging import threading +import types from twisted.internet import defer, threads @@ -528,8 +529,9 @@ def run_in_background(f, *args, **kwargs): return from the function, and that the sentinel context is set once the deferred returned by the function completes. - Useful for wrapping functions that return a deferred which you don't yield - on (for instance because you want to pass it to deferred.gatherResults()). + Useful for wrapping functions that return a deferred or coroutine, which you don't + yield or await on (for instance because you want to pass it to + deferred.gatherResults()). Note that if you completely discard the result, you should make sure that `f` doesn't raise any deferred exceptions, otherwise a scary-looking @@ -544,6 +546,9 @@ def run_in_background(f, *args, **kwargs): # by synchronous exceptions, so let's turn them into Failures. return defer.fail() + if isinstance(res, types.CoroutineType): + res = defer.ensureDeferred(res) + if not isinstance(res, defer.Deferred): return res |