diff options
author | Erik Johnston <erik@matrix.org> | 2019-07-15 14:13:22 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2019-07-15 14:13:22 +0100 |
commit | e8c53b07f2fa5cdd671841cb6feed0f6b3f8d073 (patch) | |
tree | a8105b0f3a9efd467f10500e933125bf203ab42e /synapse/rest | |
parent | Use set_defaults(func=) style (diff) | |
parent | Return a different error from Invalid Password when a user is deactivated (#5... (diff) | |
download | synapse-e8c53b07f2fa5cdd671841cb6feed0f6b3f8d073.tar.xz |
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/admin_api_cmd
Diffstat (limited to 'synapse/rest')
-rw-r--r-- | synapse/rest/admin/__init__.py | 3 | ||||
-rw-r--r-- | synapse/rest/client/transactions.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/v1/directory.py | 10 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 110 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 9 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/register.py | 11 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/relations.py | 20 | ||||
-rw-r--r-- | synapse/rest/media/v1/_base.py | 6 | ||||
-rw-r--r-- | synapse/rest/media/v1/media_repository.py | 12 | ||||
-rw-r--r-- | synapse/rest/media/v1/media_storage.py | 5 | ||||
-rw-r--r-- | synapse/rest/media/v1/preview_url_resource.py | 3 | ||||
-rw-r--r-- | synapse/rest/media/v1/storage_provider.py | 7 | ||||
-rw-r--r-- | synapse/rest/saml2/response_resource.py | 37 |
13 files changed, 110 insertions, 125 deletions
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 9843a902c6..6888ae5590 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -219,11 +219,10 @@ class UserRegisterServlet(RestServlet): register = RegisterRestServlet(self.hs) - (user_id, _) = yield register.registration_handler.register( + user_id = yield register.registration_handler.register_user( localpart=body["username"].lower(), password=body["password"], admin=bool(admin), - generate_token=False, user_type=user_type, ) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 36404b797d..6da71dc46f 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -17,8 +17,8 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.util.async_helpers import ObservableDeferred -from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index dd0d38ea5c..57542c2b4b 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -18,7 +18,13 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientCredentialsError, + NotFoundError, + SynapseError, +) from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.rest.client.v2_alpha._base import client_patterns from synapse.types import RoomAlias @@ -97,7 +103,7 @@ class ClientDirectoryServer(RestServlet): room_alias.to_string(), ) defer.returnValue((200, {})) - except AuthError: + except InvalidClientCredentialsError: # fallback to default user behaviour if they aren't an AS pass diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index ede6bc8b1e..0d05945f0a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -86,6 +86,7 @@ class LoginRestServlet(RestServlet): self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm + self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() @@ -97,6 +98,9 @@ class LoginRestServlet(RestServlet): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) + if self.saml2_enabled: + flows.append({"type": LoginRestServlet.SSO_TYPE}) + flows.append({"type": LoginRestServlet.TOKEN_TYPE}) if self.cas_enabled: flows.append({"type": LoginRestServlet.SSO_TYPE}) @@ -279,19 +283,7 @@ class LoginRestServlet(RestServlet): yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) - device_id = login_submission.get("device_id") - initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name - ) - - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - "device_id": device_id, - } - + result = yield self._register_device_with_callback(user_id, login_submission) defer.returnValue(result) @defer.inlineCallbacks @@ -320,61 +312,61 @@ class LoginRestServlet(RestServlet): user_id = UserID(user, self.hs.hostname).to_string() - auth_handler = self.auth_handler - registered_user_id = yield auth_handler.check_user_exists(user_id) - if registered_user_id: - device_id = login_submission.get("device_id") - initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name + registered_user_id = yield self.auth_handler.check_user_exists(user_id) + if not registered_user_id: + registered_user_id = yield self.registration_handler.register_user( + localpart=user ) - result = { - "user_id": registered_user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - } - else: - user_id, access_token = ( - yield self.registration_handler.register(localpart=user) - ) + result = yield self._register_device_with_callback( + registered_user_id, login_submission + ) + defer.returnValue(result) - device_id = login_submission.get("device_id") - initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } +class BaseSSORedirectServlet(RestServlet): + """Common base class for /login/sso/redirect impls""" - defer.returnValue(result) + PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + def on_GET(self, request): + args = request.args + if b"redirectUrl" not in args: + return 400, "Redirect URL not specified for SSO auth" + client_redirect_url = args[b"redirectUrl"][0] + sso_url = self.get_sso_url(client_redirect_url) + request.redirect(sso_url) + finish_request(request) -class CasRedirectServlet(RestServlet): - PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + def get_sso_url(self, client_redirect_url): + """Get the URL to redirect to, to perform SSO auth + + Args: + client_redirect_url (bytes): the URL that we should redirect the + client to when everything is done + + Returns: + bytes: URL to redirect to + """ + # to be implemented by subclasses + raise NotImplementedError() + +class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__() self.cas_server_url = hs.config.cas_server_url.encode("ascii") self.cas_service_url = hs.config.cas_service_url.encode("ascii") - def on_GET(self, request): - args = request.args - if b"redirectUrl" not in args: - return (400, "Redirect URL not specified for CAS auth") + def get_sso_url(self, client_redirect_url): client_redirect_url_param = urllib.parse.urlencode( - {b"redirectUrl": args[b"redirectUrl"][0]} + {b"redirectUrl": client_redirect_url} ).encode("ascii") hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" service_param = urllib.parse.urlencode( {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} ).encode("ascii") - request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) - finish_request(request) + return b"%s/login?%s" % (self.cas_server_url, service_param) class CasTicketServlet(RestServlet): @@ -457,6 +449,16 @@ class CasTicketServlet(RestServlet): return user, attributes +class SAMLRedirectServlet(BaseSSORedirectServlet): + PATTERNS = client_patterns("/login/sso/redirect", v1=True) + + def __init__(self, hs): + self._saml_handler = hs.get_saml_handler() + + def get_sso_url(self, client_redirect_url): + return self._saml_handler.handle_redirect_request(client_redirect_url) + + class SSOAuthHandler(object): """ Utility class for Resources and Servlets which handle the response from a SSO @@ -501,12 +503,8 @@ class SSOAuthHandler(object): user_id = UserID(localpart, self._hostname).to_string() registered_user_id = yield self._auth_handler.check_user_exists(user_id) if not registered_user_id: - registered_user_id, _ = ( - yield self._registration_handler.register( - localpart=localpart, - generate_token=False, - default_display_name=user_display_name, - ) + registered_user_id = yield self._registration_handler.register_user( + localpart=localpart, default_display_name=user_display_name ) login_token = self._macaroon_gen.generate_short_term_login_token( @@ -532,3 +530,5 @@ def register_servlets(hs, http_server): if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) + elif hs.config.saml2_enabled: + SAMLRedirectServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index cca7e45ddb..7709c2d705 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -24,7 +24,12 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientCredentialsError, + SynapseError, +) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( @@ -307,7 +312,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): try: yield self.auth.get_user_by_req(request, allow_guest=True) - except AuthError as e: + except InvalidClientCredentialsError as e: # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private # federations. diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5c120e4dd5..f327999e59 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -464,11 +464,10 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) - (registered_user_id, _) = yield self.registration_handler.register( + registered_user_id = yield self.registration_handler.register_user( localpart=desired_username, password=new_password, guest_access_token=guest_access_token, - generate_token=False, threepid=threepid, address=client_addr, ) @@ -542,8 +541,8 @@ class RegisterRestServlet(RestServlet): if not compare_digest(want_mac, got_mac): raise SynapseError(403, "HMAC incorrect") - (user_id, _) = yield self.registration_handler.register( - localpart=username, password=password, generate_token=False + user_id = yield self.registration_handler.register_user( + localpart=username, password=password ) result = yield self._create_registration_details(user_id, body) @@ -577,8 +576,8 @@ class RegisterRestServlet(RestServlet): def _do_guest_registration(self, params, address=None): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") - user_id, _ = yield self.registration_handler.register( - generate_token=False, make_guest=True, address=address + user_id = yield self.registration_handler.register_user( + make_guest=True, address=address ) # we don't allow guests to specify their own device_id, because diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 8e362782cc..7ce485b471 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -145,9 +145,9 @@ class RelationPaginationServlet(RestServlet): room_id, requester.user.to_string() ) - # This checks that a) the event exists and b) the user is allowed to - # view it. - yield self.event_handler.get_event(requester.user, room_id, parent_id) + # This gets the original event and checks that a) the event exists and + # b) the user is allowed to view it. + event = yield self.event_handler.get_event(requester.user, room_id, parent_id) limit = parse_integer(request, "limit", default=5) from_token = parse_string(request, "from") @@ -173,10 +173,22 @@ class RelationPaginationServlet(RestServlet): ) now = self.clock.time_msec() - events = yield self._event_serializer.serialize_events(events, now) + # We set bundle_aggregations to False when retrieving the original + # event because we want the content before relations were applied to + # it. + original_event = yield self._event_serializer.serialize_event( + event, now, bundle_aggregations=False + ) + # Similarly, we don't allow relations to be applied to relations, so we + # return the original relations without any aggregations on top of them + # here. + events = yield self._event_serializer.serialize_events( + events, now, bundle_aggregations=False + ) return_value = result.to_dict() return_value["chunk"] = events + return_value["original_event"] = original_event defer.returnValue((200, return_value)) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 3318638d3e..5fefee4dde 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -25,7 +25,7 @@ from twisted.protocols.basic import FileSender from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json -from synapse.util import logcontext +from synapse.logging.context import make_deferred_yieldable from synapse.util.stringutils import is_ascii logger = logging.getLogger(__name__) @@ -75,9 +75,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam add_file_headers(request, media_type, file_size, upload_name) with open(file_path, "rb") as f: - yield logcontext.make_deferred_yieldable( - FileSender().beginFileTransfer(f, request) - ) + yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) finish_request(request) else: diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index df3d985a38..65afffbb42 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -33,8 +33,8 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util import logcontext from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import random_string @@ -463,7 +463,7 @@ class MediaRepository(object): ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -511,7 +511,7 @@ class MediaRepository(object): ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -596,7 +596,7 @@ class MediaRepository(object): return if thumbnailer.transpose_method is not None: - m_width, m_height = yield logcontext.defer_to_thread( + m_width, m_height = yield defer_to_thread( self.hs.get_reactor(), thumbnailer.transpose ) @@ -616,11 +616,11 @@ class MediaRepository(object): for (t_width, t_height, t_type), t_method in iteritems(thumbnails): # Generate the thumbnail if t_method == "crop": - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type ) elif t_method == "scale": - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type ) else: diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index eff86836fb..25e5ac2848 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -24,9 +24,8 @@ import six from twisted.internet import defer from twisted.protocols.basic import FileSender -from synapse.util import logcontext +from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.util.file_consumer import BackgroundFileConsumer -from synapse.util.logcontext import make_deferred_yieldable from ._base import Responder @@ -65,7 +64,7 @@ class MediaStorage(object): with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - yield logcontext.defer_to_thread( + yield defer_to_thread( self.hs.get_reactor(), _write_file_synchronously, source, f ) yield finish_cb() diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0337b64dc2..5871737bfd 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -42,11 +42,11 @@ from synapse.http.server import ( wrap_json_request_handler, ) from synapse.http.servlet import parse_integer, parse_string +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import random_string from ._base import FileInfo @@ -95,6 +95,7 @@ class PreviewUrlResource(DirectServeResource): ) def render_OPTIONS(self, request): + request.setHeader(b"Allow", b"OPTIONS, GET") return respond_with_json(request, 200, {}, send_cors=True) @wrap_json_request_handler diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 359b45ebfc..37687ea7f4 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -20,8 +20,7 @@ import shutil from twisted.internet import defer from synapse.config._base import Config -from synapse.util import logcontext -from synapse.util.logcontext import run_in_background +from synapse.logging.context import defer_to_thread, run_in_background from .media_storage import FileResponder @@ -68,7 +67,7 @@ class StorageProviderWrapper(StorageProvider): backend (StorageProvider) store_local (bool): Whether to store new local files or not. store_synchronous (bool): Whether to wait for file to be successfully - uploaded, or todo the upload in the backgroud. + uploaded, or todo the upload in the background. store_remote (bool): Whether remote media should be uploaded """ @@ -125,7 +124,7 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return logcontext.defer_to_thread( + return defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index 939c87306c..69ecc5e4b4 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -13,17 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import saml2 -from saml2.client import Saml2Client - -from synapse.api.errors import CodeMessageException from synapse.http.server import DirectServeResource, wrap_html_request_handler -from synapse.http.servlet import parse_string -from synapse.rest.client.v1.login import SSOAuthHandler - -logger = logging.getLogger(__name__) class SAML2ResponseResource(DirectServeResource): @@ -33,32 +24,8 @@ class SAML2ResponseResource(DirectServeResource): def __init__(self, hs): super().__init__() - - self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._sso_auth_handler = SSOAuthHandler(hs) + self._saml_handler = hs.get_saml_handler() @wrap_html_request_handler async def _async_render_POST(self, request): - resp_bytes = parse_string(request, "SAMLResponse", required=True) - relay_state = parse_string(request, "RelayState", required=True) - - try: - saml2_auth = self._saml_client.parse_authn_request_response( - resp_bytes, saml2.BINDING_HTTP_POST - ) - except Exception as e: - logger.warning("Exception parsing SAML2 response", exc_info=1) - raise CodeMessageException(400, "Unable to parse SAML2 response: %s" % (e,)) - - if saml2_auth.not_signed: - raise CodeMessageException(400, "SAML2 response was not signed") - - if "uid" not in saml2_auth.ava: - raise CodeMessageException(400, "uid not in SAML2 response") - - username = saml2_auth.ava["uid"][0] - - displayName = saml2_auth.ava.get("displayName", [None])[0] - return self._sso_auth_handler.on_successful_auth( - username, request, relay_state, user_display_name=displayName - ) + return await self._saml_handler.handle_saml_response(request) |