diff options
Diffstat (limited to 'synapse/rest')
-rw-r--r-- | synapse/rest/client/transactions.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 63 | ||||
-rw-r--r-- | synapse/rest/media/v1/_base.py | 6 | ||||
-rw-r--r-- | synapse/rest/media/v1/media_repository.py | 12 | ||||
-rw-r--r-- | synapse/rest/media/v1/media_storage.py | 5 | ||||
-rw-r--r-- | synapse/rest/media/v1/preview_url_resource.py | 3 | ||||
-rw-r--r-- | synapse/rest/media/v1/storage_provider.py | 5 | ||||
-rw-r--r-- | synapse/rest/saml2/response_resource.py | 37 |
8 files changed, 66 insertions, 67 deletions
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 36404b797d..6da71dc46f 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -17,8 +17,8 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.util.async_helpers import ObservableDeferred -from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index ede6bc8b1e..f961178235 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -86,6 +86,7 @@ class LoginRestServlet(RestServlet): self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm + self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() @@ -97,6 +98,9 @@ class LoginRestServlet(RestServlet): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) + if self.saml2_enabled: + flows.append({"type": LoginRestServlet.SSO_TYPE}) + flows.append({"type": LoginRestServlet.TOKEN_TYPE}) if self.cas_enabled: flows.append({"type": LoginRestServlet.SSO_TYPE}) @@ -319,12 +323,12 @@ class LoginRestServlet(RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID(user, self.hs.hostname).to_string() + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: - device_id = login_submission.get("device_id") - initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( registered_user_id, device_id, initial_display_name ) @@ -338,11 +342,8 @@ class LoginRestServlet(RestServlet): user_id, access_token = ( yield self.registration_handler.register(localpart=user) ) - - device_id = login_submission.get("device_id") - initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name + user_id, device_id, initial_display_name ) result = { @@ -354,27 +355,49 @@ class LoginRestServlet(RestServlet): defer.returnValue(result) -class CasRedirectServlet(RestServlet): +class BaseSSORedirectServlet(RestServlet): + """Common base class for /login/sso/redirect impls""" + PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + def on_GET(self, request): + args = request.args + if b"redirectUrl" not in args: + return 400, "Redirect URL not specified for SSO auth" + client_redirect_url = args[b"redirectUrl"][0] + sso_url = self.get_sso_url(client_redirect_url) + request.redirect(sso_url) + finish_request(request) + + def get_sso_url(self, client_redirect_url): + """Get the URL to redirect to, to perform SSO auth + + Args: + client_redirect_url (bytes): the URL that we should redirect the + client to when everything is done + + Returns: + bytes: URL to redirect to + """ + # to be implemented by subclasses + raise NotImplementedError() + + +class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__() self.cas_server_url = hs.config.cas_server_url.encode("ascii") self.cas_service_url = hs.config.cas_service_url.encode("ascii") - def on_GET(self, request): - args = request.args - if b"redirectUrl" not in args: - return (400, "Redirect URL not specified for CAS auth") + def get_sso_url(self, client_redirect_url): client_redirect_url_param = urllib.parse.urlencode( - {b"redirectUrl": args[b"redirectUrl"][0]} + {b"redirectUrl": client_redirect_url} ).encode("ascii") hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" service_param = urllib.parse.urlencode( {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} ).encode("ascii") - request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) - finish_request(request) + return b"%s/login?%s" % (self.cas_server_url, service_param) class CasTicketServlet(RestServlet): @@ -457,6 +480,16 @@ class CasTicketServlet(RestServlet): return user, attributes +class SAMLRedirectServlet(BaseSSORedirectServlet): + PATTERNS = client_patterns("/login/sso/redirect", v1=True) + + def __init__(self, hs): + self._saml_handler = hs.get_saml_handler() + + def get_sso_url(self, client_redirect_url): + return self._saml_handler.handle_redirect_request(client_redirect_url) + + class SSOAuthHandler(object): """ Utility class for Resources and Servlets which handle the response from a SSO @@ -532,3 +565,5 @@ def register_servlets(hs, http_server): if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) + elif hs.config.saml2_enabled: + SAMLRedirectServlet(hs).register(http_server) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 3318638d3e..5fefee4dde 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -25,7 +25,7 @@ from twisted.protocols.basic import FileSender from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json -from synapse.util import logcontext +from synapse.logging.context import make_deferred_yieldable from synapse.util.stringutils import is_ascii logger = logging.getLogger(__name__) @@ -75,9 +75,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam add_file_headers(request, media_type, file_size, upload_name) with open(file_path, "rb") as f: - yield logcontext.make_deferred_yieldable( - FileSender().beginFileTransfer(f, request) - ) + yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) finish_request(request) else: diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index df3d985a38..65afffbb42 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -33,8 +33,8 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util import logcontext from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import random_string @@ -463,7 +463,7 @@ class MediaRepository(object): ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -511,7 +511,7 @@ class MediaRepository(object): ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -596,7 +596,7 @@ class MediaRepository(object): return if thumbnailer.transpose_method is not None: - m_width, m_height = yield logcontext.defer_to_thread( + m_width, m_height = yield defer_to_thread( self.hs.get_reactor(), thumbnailer.transpose ) @@ -616,11 +616,11 @@ class MediaRepository(object): for (t_width, t_height, t_type), t_method in iteritems(thumbnails): # Generate the thumbnail if t_method == "crop": - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type ) elif t_method == "scale": - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type ) else: diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index eff86836fb..25e5ac2848 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -24,9 +24,8 @@ import six from twisted.internet import defer from twisted.protocols.basic import FileSender -from synapse.util import logcontext +from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.util.file_consumer import BackgroundFileConsumer -from synapse.util.logcontext import make_deferred_yieldable from ._base import Responder @@ -65,7 +64,7 @@ class MediaStorage(object): with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - yield logcontext.defer_to_thread( + yield defer_to_thread( self.hs.get_reactor(), _write_file_synchronously, source, f ) yield finish_cb() diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0337b64dc2..5871737bfd 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -42,11 +42,11 @@ from synapse.http.server import ( wrap_json_request_handler, ) from synapse.http.servlet import parse_integer, parse_string +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import random_string from ._base import FileInfo @@ -95,6 +95,7 @@ class PreviewUrlResource(DirectServeResource): ) def render_OPTIONS(self, request): + request.setHeader(b"Allow", b"OPTIONS, GET") return respond_with_json(request, 200, {}, send_cors=True) @wrap_json_request_handler diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 359b45ebfc..e8f559acc1 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -20,8 +20,7 @@ import shutil from twisted.internet import defer from synapse.config._base import Config -from synapse.util import logcontext -from synapse.util.logcontext import run_in_background +from synapse.logging.context import defer_to_thread, run_in_background from .media_storage import FileResponder @@ -125,7 +124,7 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return logcontext.defer_to_thread( + return defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index 939c87306c..69ecc5e4b4 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -13,17 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import saml2 -from saml2.client import Saml2Client - -from synapse.api.errors import CodeMessageException from synapse.http.server import DirectServeResource, wrap_html_request_handler -from synapse.http.servlet import parse_string -from synapse.rest.client.v1.login import SSOAuthHandler - -logger = logging.getLogger(__name__) class SAML2ResponseResource(DirectServeResource): @@ -33,32 +24,8 @@ class SAML2ResponseResource(DirectServeResource): def __init__(self, hs): super().__init__() - - self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._sso_auth_handler = SSOAuthHandler(hs) + self._saml_handler = hs.get_saml_handler() @wrap_html_request_handler async def _async_render_POST(self, request): - resp_bytes = parse_string(request, "SAMLResponse", required=True) - relay_state = parse_string(request, "RelayState", required=True) - - try: - saml2_auth = self._saml_client.parse_authn_request_response( - resp_bytes, saml2.BINDING_HTTP_POST - ) - except Exception as e: - logger.warning("Exception parsing SAML2 response", exc_info=1) - raise CodeMessageException(400, "Unable to parse SAML2 response: %s" % (e,)) - - if saml2_auth.not_signed: - raise CodeMessageException(400, "SAML2 response was not signed") - - if "uid" not in saml2_auth.ava: - raise CodeMessageException(400, "uid not in SAML2 response") - - username = saml2_auth.ava["uid"][0] - - displayName = saml2_auth.ava.get("displayName", [None])[0] - return self._sso_auth_handler.on_successful_auth( - username, request, relay_state, user_display_name=displayName - ) + return await self._saml_handler.handle_saml_response(request) |