From 7e15410f0269aaea204b59fb3f7891e2437331ec Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 31 May 2018 18:17:11 +0100 Subject: Enforce the specified API for report_event as per https://matrix.org/docs/spec/client_server/unstable.html#post-matrix-client-r0-rooms-roomid-report-eventid --- synapse/rest/client/v2_alpha/report_event.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 8903e12405..8a38be6482 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -13,9 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six import string_types +from six.moves import http_client + from twisted.internet import defer -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.api.errors import SynapseError, Codes +from synapse.http.servlet import ( + RestServlet, parse_json_object_from_request, + assert_params_in_request, +) from ._base import client_v2_patterns import logging @@ -42,12 +49,26 @@ class ReportEventRestServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) + assert_params_in_request(body, ("reason", "score")) + + if not isinstance(body["reason"], string_types): + raise SynapseError( + http_client.BAD_REQUEST, + "Param 'reason' must be a string", + Codes.BAD_JSON, + ) + if not isinstance(body["score"], int): + raise SynapseError( + http_client.BAD_REQUEST, + "Param 'score' must be an integer", + Codes.BAD_JSON, + ) yield self.store.add_event_report( room_id=room_id, event_id=event_id, user_id=user_id, - reason=body.get("reason"), + reason=body["reason"], content=body, received_ts=self.clock.time_msec(), ) -- cgit 1.4.1 From 32fd6910d0175f14d9be756ac2241da683af83bb Mon Sep 17 00:00:00 2001 From: Krombel Date: Fri, 13 Jul 2018 21:40:14 +0200 Subject: Use parse_{int,str} and assert from http.servlet parse_integer and parse_string can take a request and raise errors in case we have wrong or missing params. This PR tries to use them more to deduplicate some code and make it better readable --- synapse/rest/client/v1/admin.py | 47 +++++++++------------------ synapse/rest/client/v1/directory.py | 11 +++---- synapse/rest/client/v1/initial_sync.py | 3 +- synapse/rest/client/v1/push_rule.py | 6 ++-- synapse/rest/client/v1/pusher.py | 17 ++++------ synapse/rest/client/v1/register.py | 24 +++----------- synapse/rest/client/v1/room.py | 10 +++--- synapse/rest/client/v2_alpha/account.py | 40 +++++------------------ synapse/rest/client/v2_alpha/devices.py | 26 ++++++++------- synapse/rest/client/v2_alpha/register.py | 17 +++++----- synapse/rest/media/v1/identicon_resource.py | 6 ++-- synapse/rest/media/v1/preview_url_resource.py | 5 +-- synapse/rest/media/v1/upload_resource.py | 5 +-- synapse/streams/config.py | 29 ++++------------- 14 files changed, 91 insertions(+), 155 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 2b091d61a5..6b3c496418 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -22,7 +22,12 @@ from twisted.internet import defer from synapse.api.constants import Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import ( + assert_params_in_request, + parse_json_object_from_request, + parse_integer, + parse_string +) from synapse.types import UserID, create_requester from .base import ClientV1RestServlet, client_path_patterns @@ -98,16 +103,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - before_ts = request.args.get("before_ts", None) - if not before_ts: - raise SynapseError(400, "Missing 'before_ts' arg") - - logger.info("before_ts: %r", before_ts[0]) - - try: - before_ts = int(before_ts[0]) - except Exception: - raise SynapseError(400, "Invalid 'before_ts' arg") + before_ts = parse_integer(request, "before_ts", required=True) + logger.info("before_ts: %r", before_ts) ret = yield self.media_repository.delete_old_remote_media(before_ts) @@ -300,10 +297,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): raise AuthError(403, "You are not a server admin") content = parse_json_object_from_request(request) - - new_room_user_id = content.get("new_room_user_id") - if not new_room_user_id: - raise SynapseError(400, "Please provide field `new_room_user_id`") + assert_params_in_request(content, ["new_room_user_id"]) + new_room_user_id = content["new_room_user_id"] room_creator_requester = create_requester(new_room_user_id) @@ -464,9 +459,8 @@ class ResetPasswordRestServlet(ClientV1RestServlet): raise AuthError(403, "You are not a server admin") params = parse_json_object_from_request(request) + assert_params_in_request(params, ["new_password"]) new_password = params['new_password'] - if not new_password: - raise SynapseError(400, "Missing 'new_password' arg") logger.info("new_password: %r", new_password) @@ -514,12 +508,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): raise SynapseError(400, "Can only users a local user") order = "name" # order by name in user table - start = request.args.get("start")[0] - limit = request.args.get("limit")[0] - if not limit: - raise SynapseError(400, "Missing 'limit' arg") - if not start: - raise SynapseError(400, "Missing 'start' arg") + start = parse_integer(request, "start", required=True) + limit = parse_integer(request, "limit", required=True) + logger.info("limit: %s, start: %s", limit, start) ret = yield self.handlers.admin_handler.get_users_paginate( @@ -551,12 +542,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): order = "name" # order by name in user table params = parse_json_object_from_request(request) + assert_params_in_request(params, ["limit", "start"]) limit = params['limit'] start = params['start'] - if not limit: - raise SynapseError(400, "Missing 'limit' arg") - if not start: - raise SynapseError(400, "Missing 'start' arg") logger.info("limit: %s, start: %s", limit, start) ret = yield self.handlers.admin_handler.get_users_paginate( @@ -604,10 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet): if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only users a local user") - term = request.args.get("term")[0] - if not term: - raise SynapseError(400, "Missing 'term' arg") - + term = parse_string(request, "term", required=True) logger.info("term: %s ", term) ret = yield self.handlers.admin_handler.search_users( diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 4fdbb83815..3003cde94e 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -18,8 +18,8 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.http.servlet import parse_json_object_from_request +from synapse.api.errors import AuthError, SynapseError +from synapse.http.servlet import assert_params_in_request, parse_json_object_from_request from synapse.types import RoomAlias from .base import ClientV1RestServlet, client_path_patterns @@ -52,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_alias): + room_alias = RoomAlias.from_string(room_alias) + content = parse_json_object_from_request(request) if "room_id" not in content: - raise SynapseError(400, "Missing room_id key", + raise SynapseError(400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON) logger.debug("Got content: %s", content) - - room_alias = RoomAlias.from_string(room_alias) - logger.debug("Got room name: %s", room_alias.to_string()) room_id = content["room_id"] diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index fbe8cb2023..00a1a99feb 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.streams.config import PaginationConfig +from synapse.http.servlet import parse_boolean from .base import ClientV1RestServlet, client_path_patterns @@ -33,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) - include_archived = request.args.get("archived", None) == ["true"] + include_archived = parse_boolean(request, "archived", default=False) content = yield self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), pagin_config=pagination_config, diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 0df7ce570f..7cf6a99774 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -21,7 +21,7 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) -from synapse.http.servlet import parse_json_value_from_request +from synapse.http.servlet import parse_json_value_from_request, parse_string from synapse.push.baserules import BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP @@ -75,11 +75,11 @@ class PushRuleRestServlet(ClientV1RestServlet): except InvalidRuleException as e: raise SynapseError(400, e.message) - before = request.args.get("before", None) + before = parse_string(request, "before") if before: before = _namespaced_rule_id(spec, before[0]) - after = request.args.get("after", None) + after = parse_string(request, "after") if after: after = _namespaced_rule_id(spec, after[0]) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 1581f88db5..95b9252b61 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -21,6 +21,7 @@ from synapse.api.errors import Codes, StoreError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, + assert_params_in_request, parse_json_object_from_request, parse_string, ) @@ -91,15 +92,11 @@ class PushersSetRestServlet(ClientV1RestServlet): ) defer.returnValue((200, {})) - reqd = ['kind', 'app_id', 'app_display_name', - 'device_display_name', 'pushkey', 'lang', 'data'] - missing = [] - for i in reqd: - if i not in content: - missing.append(i) - if len(missing): - raise SynapseError(400, "Missing parameters: " + ','.join(missing), - errcode=Codes.MISSING_PARAM) + assert_params_in_request( + content, + ['kind', 'app_id', 'app_display_name', + 'device_display_name', 'pushkey', 'lang', 'data'] + ) logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) logger.debug("Got pushers request with body: %r", content) @@ -148,7 +145,7 @@ class PushersRemoveRestServlet(RestServlet): SUCCESS_HTML = "You have been unsubscribed" def __init__(self, hs): - super(RestServlet, self).__init__() + super(PushersRemoveRestServlet, self).__init__() self.hs = hs self.notifier = hs.get_notifier() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 3ce5f8b726..744ed04455 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -18,15 +18,13 @@ import hmac import logging from hashlib import sha1 -from six import string_types - from twisted.internet import defer import synapse.util.stringutils as stringutils from synapse.api.auth import get_access_token_from_request from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import assert_params_in_request, parse_json_object_from_request from synapse.types import create_requester from .base import ClientV1RestServlet, client_path_patterns @@ -124,8 +122,7 @@ class RegisterRestServlet(ClientV1RestServlet): session = (register_json["session"] if "session" in register_json else None) login_type = None - if "type" not in register_json: - raise SynapseError(400, "Missing 'type' key.") + assert_params_in_request(register_json, ["type"]) try: login_type = register_json["type"] @@ -312,9 +309,7 @@ class RegisterRestServlet(ClientV1RestServlet): def _do_app_service(self, request, register_json, session): as_token = get_access_token_from_request(request) - if "user" not in register_json: - raise SynapseError(400, "Expected 'user' key.") - + assert_params_in_request(register_json, ["user"]) user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler @@ -331,12 +326,7 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_shared_secret(self, request, register_json, session): - if not isinstance(register_json.get("mac", None), string_types): - raise SynapseError(400, "Expected mac.") - if not isinstance(register_json.get("user", None), string_types): - raise SynapseError(400, "Expected 'user' key.") - if not isinstance(register_json.get("password", None), string_types): - raise SynapseError(400, "Expected 'password' key.") + assert_params_in_request(register_json, ["mac", "user", "password"]) if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") @@ -419,11 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_create(self, requester, user_json): - if "localpart" not in user_json: - raise SynapseError(400, "Expected 'localpart' key.") - - if "displayname" not in user_json: - raise SynapseError(400, "Expected 'displayname' key.") + assert_params_in_request(user_json, ["localpart", "displayname"]) localpart = user_json["localpart"].encode("utf-8") displayname = user_json["displayname"].encode("utf-8") diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2470db52ba..3050d040a2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -28,6 +28,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2, serialize_event from synapse.http.servlet import ( + assert_params_in_request, parse_integer, parse_json_object_from_request, parse_string, @@ -435,7 +436,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): request, default_limit=10, ) as_client_event = "raw" not in request.args - filter_bytes = request.args.get("filter", None) + filter_bytes = parse_string(request, "filter") if filter_bytes: filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8") event_filter = Filter(json.loads(filter_json)) @@ -530,7 +531,7 @@ class RoomEventContextServlet(ClientV1RestServlet): def on_GET(self, request, room_id, event_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - limit = int(request.args.get("limit", [10])[0]) + limit = parse_integer(request, "limit", default=10) results = yield self.handlers.room_context_handler.get_event_context( requester.user, @@ -636,8 +637,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): target = requester.user if membership_action in ["invite", "ban", "unban", "kick"]: - if "user_id" not in content: - raise SynapseError(400, "Missing user_id key.") + assert_params_in_request(content, ["user_id"]) target = UserID.from_string(content["user_id"]) event_content = None @@ -764,7 +764,7 @@ class SearchRestServlet(ClientV1RestServlet): content = parse_json_object_from_request(request) - batch = request.args.get("next_batch", [None])[0] + batch = parse_string(request, "next_batch") results = yield self.handlers.search_handler.search( requester.user, content, diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 528c1f43f9..2952472a4a 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -160,11 +160,10 @@ class PasswordRestServlet(RestServlet): raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) user_id = threepid_user_id else: - logger.error("Auth succeeded but no known type!", result.keys()) + logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - if 'new_password' not in params: - raise SynapseError(400, "", Codes.MISSING_PARAM) + assert_params_in_request(params, ["new_password"]) new_password = params['new_password'] yield self._set_password_handler.set_password( @@ -229,15 +228,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - - required = ['id_server', 'client_secret', 'email', 'send_attempt'] - absent = [] - for k in required: - if k not in body: - absent.append(k) - - if absent: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + assert_params_in_request( + body, + ['id_server', 'client_secret', 'email', 'send_attempt'], + ) if not check_3pid_allowed(self.hs, "email", body['email']): raise SynapseError( @@ -267,18 +261,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - - required = [ + assert_params_in_request(body, [ 'id_server', 'client_secret', 'country', 'phone_number', 'send_attempt', - ] - absent = [] - for k in required: - if k not in body: - absent.append(k) - - if absent: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + ]) msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) @@ -373,15 +359,7 @@ class ThreepidDeleteRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - - required = ['medium', 'address'] - absent = [] - for k in required: - if k not in body: - absent.append(k) - - if absent: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + assert_params_in_request(body, ['medium', 'address']) requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 09f6a8efe3..e346ad4ed6 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -18,14 +18,18 @@ import logging from twisted.internet import defer from synapse.api import errors -from synapse.http import servlet +from synapse.http.servlet import ( + assert_params_in_request, + parse_json_object_from_request, + RestServlet +) from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) -class DevicesRestServlet(servlet.RestServlet): +class DevicesRestServlet(RestServlet): PATTERNS = client_v2_patterns("/devices$", v2_alpha=False) def __init__(self, hs): @@ -47,7 +51,7 @@ class DevicesRestServlet(servlet.RestServlet): defer.returnValue((200, {"devices": devices})) -class DeleteDevicesRestServlet(servlet.RestServlet): +class DeleteDevicesRestServlet(RestServlet): """ API for bulk deletion of devices. Accepts a JSON object with a devices key which lists the device_ids to delete. Requires user interactive auth. @@ -67,19 +71,17 @@ class DeleteDevicesRestServlet(servlet.RestServlet): requester = yield self.auth.get_user_by_req(request) try: - body = servlet.parse_json_object_from_request(request) + body = parse_json_object_from_request(request) except errors.SynapseError as e: if e.errcode == errors.Codes.NOT_JSON: - # deal with older clients which didn't pass a J*DELETESON dict + # DELETE + # deal with older clients which didn't pass a JSON dict # the same as those that pass an empty dict body = {} else: raise e - if 'devices' not in body: - raise errors.SynapseError( - 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM - ) + assert_params_in_request(body, ["devices"]) yield self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request), @@ -92,7 +94,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet): defer.returnValue((200, {})) -class DeviceRestServlet(servlet.RestServlet): +class DeviceRestServlet(RestServlet): PATTERNS = client_v2_patterns("/devices/(?P[^/]*)$", v2_alpha=False) def __init__(self, hs): @@ -121,7 +123,7 @@ class DeviceRestServlet(servlet.RestServlet): requester = yield self.auth.get_user_by_req(request) try: - body = servlet.parse_json_object_from_request(request) + body = parse_json_object_from_request(request) except errors.SynapseError as e: if e.errcode == errors.Codes.NOT_JSON: @@ -144,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet): def on_PUT(self, request, device_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - body = servlet.parse_json_object_from_request(request) + body = parse_json_object_from_request(request) yield self.device_handler.update_device( requester.user.to_string(), device_id, diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 896650d5a5..e2023e3a61 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -387,9 +387,7 @@ class RegisterRestServlet(RestServlet): add_msisdn = False else: # NB: This may be from the auth handler and NOT from the POST - if 'password' not in params: - raise SynapseError(400, "Missing password.", - Codes.MISSING_PARAM) + assert_params_in_request(params, ["password"]) desired_username = params.get("username", None) new_password = params.get("password", None) @@ -566,11 +564,14 @@ class RegisterRestServlet(RestServlet): Returns: defer.Deferred: """ - reqd = ('medium', 'address', 'validated_at') - if any(x not in threepid for x in reqd): - # This will only happen if the ID server returns a malformed response - logger.info("Can't add incomplete 3pid") - defer.returnValue() + try: + assert_params_in_request(threepid, ['medium', 'address', 'validated_at']) + except SynapseError as ex: + if ex.errcode == Codes.MISSING_PARAM: + # This will only happen if the ID server returns a malformed response + logger.info("Can't add incomplete 3pid") + defer.returnValue(None) + raise yield self.auth_handler.add_threepid( user_id, diff --git a/synapse/rest/media/v1/identicon_resource.py b/synapse/rest/media/v1/identicon_resource.py index a2e391415f..b3217eff53 100644 --- a/synapse/rest/media/v1/identicon_resource.py +++ b/synapse/rest/media/v1/identicon_resource.py @@ -14,6 +14,8 @@ from pydenticon import Generator +from synapse.http.servlet import parse_integer + from twisted.web.resource import Resource FOREGROUND = [ @@ -56,8 +58,8 @@ class IdenticonResource(Resource): def render_GET(self, request): name = "/".join(request.postpath) - width = int(request.args.get("width", [96])[0]) - height = int(request.args.get("height", [96])[0]) + width = parse_integer(request, "width", default=96) + height = parse_integer(request, "height", default=96) identicon_bytes = self.generate_identicon(name, width, height) request.setHeader(b"Content-Type", b"image/png") request.setHeader( diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 4e3a18ce08..b70b15c4c2 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -40,6 +40,7 @@ from synapse.http.server import ( respond_with_json_bytes, wrap_json_request_handler, ) +from synapse.http.servlet import parse_integer, parse_string from synapse.util.async import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background @@ -96,9 +97,9 @@ class PreviewUrlResource(Resource): # XXX: if get_user_by_req fails, what should we do in an async render? requester = yield self.auth.get_user_by_req(request) - url = request.args.get("url")[0] + url = parse_string(request, "url") if "ts" in request.args: - ts = int(request.args.get("ts")[0]) + ts = parse_integer(request, "ts") else: ts = self.clock.time_msec() diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 1a98120e1d..9b22d204a6 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -21,6 +21,7 @@ from twisted.web.server import NOT_DONE_YET from synapse.api.errors import SynapseError from synapse.http.server import respond_with_json, wrap_json_request_handler +from synapse.http.servlet import parse_string logger = logging.getLogger(__name__) @@ -65,10 +66,10 @@ class UploadResource(Resource): code=413, ) - upload_name = request.args.get("filename", None) + upload_name = parse_string(request, "filename") if upload_name: try: - upload_name = upload_name[0].decode('UTF-8') + upload_name = upload_name.decode('UTF-8') except UnicodeDecodeError: raise SynapseError( msg="Invalid UTF-8 filename parameter: %r" % (upload_name), diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 46ccbbda7d..22b0b830f7 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -16,6 +16,7 @@ import logging from synapse.api.errors import SynapseError +from synapse.http.servlet import parse_integer, parse_string from synapse.types import StreamToken logger = logging.getLogger(__name__) @@ -56,23 +57,10 @@ class PaginationConfig(object): @classmethod def from_request(cls, request, raise_invalid_params=True, default_limit=None): - def get_param(name, default=None): - lst = request.args.get(name, []) - if len(lst) > 1: - raise SynapseError( - 400, "%s must be specified only once" % (name,) - ) - elif len(lst) == 1: - return lst[0] - else: - return default - - direction = get_param("dir", 'f') - if direction not in ['f', 'b']: - raise SynapseError(400, "'dir' parameter is invalid.") - - from_tok = get_param("from") - to_tok = get_param("to") + direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b']) + + from_tok = parse_string(request, "from") + to_tok = parse_string(request, "to") try: if from_tok == "END": @@ -88,12 +76,7 @@ class PaginationConfig(object): except Exception: raise SynapseError(400, "'to' paramater is invalid") - limit = get_param("limit", None) - if limit is not None and not limit.isdigit(): - raise SynapseError(400, "'limit' parameter must be an integer.") - - if limit is None: - limit = default_limit + limit = parse_integer(request, "limit", default=default_limit) try: return PaginationConfig(from_tok, to_tok, direction, limit) -- cgit 1.4.1 From 3366b9c534cc286fcbe218ab718fd9a9c6211da4 Mon Sep 17 00:00:00 2001 From: Krombel Date: Fri, 13 Jul 2018 21:53:01 +0200 Subject: rename assert_params_in_request to assert_params_in_dict the method "assert_params_in_request" does handle dicts and not requests. A request body has to be parsed to json before this method can be used --- synapse/federation/federation_base.py | 4 ++-- synapse/http/servlet.py | 2 +- synapse/rest/client/v1/admin.py | 8 ++++---- synapse/rest/client/v1/directory.py | 2 +- synapse/rest/client/v1/pusher.py | 4 ++-- synapse/rest/client/v1/register.py | 10 +++++----- synapse/rest/client/v1/room.py | 4 ++-- synapse/rest/client/v2_alpha/account.py | 14 +++++++------- synapse/rest/client/v2_alpha/devices.py | 4 ++-- synapse/rest/client/v2_alpha/register.py | 10 +++++----- synapse/rest/client/v2_alpha/report_event.py | 4 ++-- 11 files changed, 33 insertions(+), 33 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index f0c7a06718..c11798093d 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -23,7 +23,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.crypto.event_signing import check_event_content_hash from synapse.events import FrozenEvent from synapse.events.utils import prune_event -from synapse.http.servlet import assert_params_in_request +from synapse.http.servlet import assert_params_in_dict from synapse.util import logcontext, unwrapFirstError logger = logging.getLogger(__name__) @@ -199,7 +199,7 @@ def event_from_pdu_json(pdu_json, outlier=False): """ # we could probably enforce a bunch of other fields here (room_id, sender, # origin, etc etc) - assert_params_in_request(pdu_json, ('event_id', 'type', 'depth')) + assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth')) depth = pdu_json['depth'] if not isinstance(depth, six.integer_types): diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index cf6723563a..882816dc8f 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -206,7 +206,7 @@ def parse_json_object_from_request(request, allow_empty_body=False): return content -def assert_params_in_request(body, required): +def assert_params_in_dict(body, required): absent = [] for k in required: if k not in body: diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 6b3c496418..01c3f2eb04 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -23,7 +23,7 @@ from twisted.internet import defer from synapse.api.constants import Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import ( - assert_params_in_request, + assert_params_in_dict, parse_json_object_from_request, parse_integer, parse_string @@ -297,7 +297,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): raise AuthError(403, "You are not a server admin") content = parse_json_object_from_request(request) - assert_params_in_request(content, ["new_room_user_id"]) + assert_params_in_dict(content, ["new_room_user_id"]) new_room_user_id = content["new_room_user_id"] room_creator_requester = create_requester(new_room_user_id) @@ -459,7 +459,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet): raise AuthError(403, "You are not a server admin") params = parse_json_object_from_request(request) - assert_params_in_request(params, ["new_password"]) + assert_params_in_dict(params, ["new_password"]) new_password = params['new_password'] logger.info("new_password: %r", new_password) @@ -542,7 +542,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): order = "name" # order by name in user table params = parse_json_object_from_request(request) - assert_params_in_request(params, ["limit", "start"]) + assert_params_in_dict(params, ["limit", "start"]) limit = params['limit'] start = params['start'] logger.info("limit: %s, start: %s", limit, start) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 3003cde94e..c320612aea 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -19,7 +19,7 @@ import logging from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError -from synapse.http.servlet import assert_params_in_request, parse_json_object_from_request +from synapse.http.servlet import parse_json_object_from_request from synapse.types import RoomAlias from .base import ClientV1RestServlet, client_path_patterns diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 95b9252b61..182a68b1e2 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -21,7 +21,7 @@ from synapse.api.errors import Codes, StoreError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, - assert_params_in_request, + assert_params_in_dict, parse_json_object_from_request, parse_string, ) @@ -92,7 +92,7 @@ class PushersSetRestServlet(ClientV1RestServlet): ) defer.returnValue((200, {})) - assert_params_in_request( + assert_params_in_dict( content, ['kind', 'app_id', 'app_display_name', 'device_display_name', 'pushkey', 'lang', 'data'] diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 744ed04455..b2d5baccdc 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -24,7 +24,7 @@ import synapse.util.stringutils as stringutils from synapse.api.auth import get_access_token_from_request from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import assert_params_in_request, parse_json_object_from_request +from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.types import create_requester from .base import ClientV1RestServlet, client_path_patterns @@ -122,7 +122,7 @@ class RegisterRestServlet(ClientV1RestServlet): session = (register_json["session"] if "session" in register_json else None) login_type = None - assert_params_in_request(register_json, ["type"]) + assert_params_in_dict(register_json, ["type"]) try: login_type = register_json["type"] @@ -309,7 +309,7 @@ class RegisterRestServlet(ClientV1RestServlet): def _do_app_service(self, request, register_json, session): as_token = get_access_token_from_request(request) - assert_params_in_request(register_json, ["user"]) + assert_params_in_dict(register_json, ["user"]) user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler @@ -326,7 +326,7 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_shared_secret(self, request, register_json, session): - assert_params_in_request(register_json, ["mac", "user", "password"]) + assert_params_in_dict(register_json, ["mac", "user", "password"]) if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") @@ -409,7 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_create(self, requester, user_json): - assert_params_in_request(user_json, ["localpart", "displayname"]) + assert_params_in_dict(user_json, ["localpart", "displayname"]) localpart = user_json["localpart"].encode("utf-8") displayname = user_json["displayname"].encode("utf-8") diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 3050d040a2..5a867b4cf2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -28,7 +28,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2, serialize_event from synapse.http.servlet import ( - assert_params_in_request, + assert_params_in_dict, parse_integer, parse_json_object_from_request, parse_string, @@ -637,7 +637,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): target = requester.user if membership_action in ["invite", "ban", "unban", "kick"]: - assert_params_in_request(content, ["user_id"]) + assert_params_in_dict(content, ["user_id"]) target = UserID.from_string(content["user_id"]) event_content = None diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 2952472a4a..28a5e76ae3 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -25,7 +25,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, - assert_params_in_request, + assert_params_in_dict, parse_json_object_from_request, ) from synapse.util.msisdn import phone_number_to_msisdn @@ -48,7 +48,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_request(body, [ + assert_params_in_dict(body, [ 'id_server', 'client_secret', 'email', 'send_attempt' ]) @@ -81,7 +81,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_request(body, [ + assert_params_in_dict(body, [ 'id_server', 'client_secret', 'country', 'phone_number', 'send_attempt', ]) @@ -163,7 +163,7 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - assert_params_in_request(params, ["new_password"]) + assert_params_in_dict(params, ["new_password"]) new_password = params['new_password'] yield self._set_password_handler.set_password( @@ -228,7 +228,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_request( + assert_params_in_dict( body, ['id_server', 'client_secret', 'email', 'send_attempt'], ) @@ -261,7 +261,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_request(body, [ + assert_params_in_dict(body, [ 'id_server', 'client_secret', 'country', 'phone_number', 'send_attempt', ]) @@ -359,7 +359,7 @@ class ThreepidDeleteRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_request(body, ['medium', 'address']) + assert_params_in_dict(body, ['medium', 'address']) requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index e346ad4ed6..aded2409be 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.api import errors from synapse.http.servlet import ( - assert_params_in_request, + assert_params_in_dict, parse_json_object_from_request, RestServlet ) @@ -81,7 +81,7 @@ class DeleteDevicesRestServlet(RestServlet): else: raise e - assert_params_in_request(body, ["devices"]) + assert_params_in_dict(body, ["devices"]) yield self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request), diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index e2023e3a61..0d0eeb8c0b 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -29,7 +29,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError from synapse.http.servlet import ( RestServlet, - assert_params_in_request, + assert_params_in_dict, parse_json_object_from_request, parse_string, ) @@ -69,7 +69,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_request(body, [ + assert_params_in_dict(body, [ 'id_server', 'client_secret', 'email', 'send_attempt' ]) @@ -105,7 +105,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_request(body, [ + assert_params_in_dict(body, [ 'id_server', 'client_secret', 'country', 'phone_number', 'send_attempt', @@ -387,7 +387,7 @@ class RegisterRestServlet(RestServlet): add_msisdn = False else: # NB: This may be from the auth handler and NOT from the POST - assert_params_in_request(params, ["password"]) + assert_params_in_dict(params, ["password"]) desired_username = params.get("username", None) new_password = params.get("password", None) @@ -565,7 +565,7 @@ class RegisterRestServlet(RestServlet): defer.Deferred: """ try: - assert_params_in_request(threepid, ['medium', 'address', 'validated_at']) + assert_params_in_dict(threepid, ['medium', 'address', 'validated_at']) except SynapseError as ex: if ex.errcode == Codes.MISSING_PARAM: # This will only happen if the ID server returns a malformed response diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 0cc2a71c3b..95d2a71ec2 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -23,7 +23,7 @@ from twisted.internet import defer from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, - assert_params_in_request, + assert_params_in_dict, parse_json_object_from_request, ) @@ -50,7 +50,7 @@ class ReportEventRestServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) - assert_params_in_request(body, ("reason", "score")) + assert_params_in_dict(body, ("reason", "score")) if not isinstance(body["reason"], string_types): raise SynapseError( -- cgit 1.4.1 From 516f960ad8f4ae8b4fe3740baac9a63a966a2642 Mon Sep 17 00:00:00 2001 From: Krombel Date: Fri, 13 Jul 2018 22:19:19 +0200 Subject: add changelog --- changelog.d/3534.misc | 1 + synapse/rest/client/v1/directory.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/3534.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/3534.misc b/changelog.d/3534.misc new file mode 100644 index 0000000000..949c12dc69 --- /dev/null +++ b/changelog.d/3534.misc @@ -0,0 +1 @@ +refactor: use parse_{string,integer} and assert's from http.servlet for deduplication diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index c320612aea..69dcd618cb 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -18,7 +18,7 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.http.servlet import parse_json_object_from_request from synapse.types import RoomAlias -- cgit 1.4.1 From 33b60c01b5e60dc38b02b6f0fd018be2d2d07cf7 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 14 Jul 2018 07:34:49 +1000 Subject: Make auth & transactions more testable (#3499) --- changelog.d/3499.misc | 0 synapse/api/auth.py | 124 +++++++++++++-------------- synapse/rest/client/transactions.py | 45 +++++----- synapse/rest/client/v1/base.py | 2 +- synapse/rest/client/v1/logout.py | 3 +- synapse/rest/client/v1/register.py | 6 +- synapse/rest/client/v2_alpha/account.py | 3 +- synapse/rest/client/v2_alpha/register.py | 5 +- synapse/rest/client/v2_alpha/sendtodevice.py | 2 +- tests/rest/client/test_transactions.py | 5 +- 10 files changed, 97 insertions(+), 98 deletions(-) create mode 100644 changelog.d/3499.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/3499.misc b/changelog.d/3499.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 6dec862fec..bc629832d9 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -193,7 +193,7 @@ class Auth(object): synapse.types.create_requester(user_id, app_service=app_service) ) - access_token = get_access_token_from_request( + access_token = self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) @@ -239,7 +239,7 @@ class Auth(object): @defer.inlineCallbacks def _get_appservice_user_id(self, request): app_service = self.store.get_app_service_by_token( - get_access_token_from_request( + self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) ) @@ -513,7 +513,7 @@ class Auth(object): def get_appservice_by_req(self, request): try: - token = get_access_token_from_request( + token = self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) service = self.store.get_app_service_by_token(token) @@ -673,67 +673,67 @@ class Auth(object): " edit its room list entry" ) + @staticmethod + def has_access_token(request): + """Checks if the request has an access_token. -def has_access_token(request): - """Checks if the request has an access_token. + Returns: + bool: False if no access_token was given, True otherwise. + """ + query_params = request.args.get("access_token") + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + return bool(query_params) or bool(auth_headers) - Returns: - bool: False if no access_token was given, True otherwise. - """ - query_params = request.args.get("access_token") - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") - return bool(query_params) or bool(auth_headers) - - -def get_access_token_from_request(request, token_not_found_http_status=401): - """Extracts the access_token from the request. - - Args: - request: The http request. - token_not_found_http_status(int): The HTTP status code to set in the - AuthError if the token isn't found. This is used in some of the - legacy APIs to change the status code to 403 from the default of - 401 since some of the old clients depended on auth errors returning - 403. - Returns: - str: The access_token - Raises: - AuthError: If there isn't an access_token in the request. - """ + @staticmethod + def get_access_token_from_request(request, token_not_found_http_status=401): + """Extracts the access_token from the request. - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") - query_params = request.args.get(b"access_token") - if auth_headers: - # Try the get the access_token from a "Authorization: Bearer" - # header - if query_params is not None: - raise AuthError( - token_not_found_http_status, - "Mixing Authorization headers and access_token query parameters.", - errcode=Codes.MISSING_TOKEN, - ) - if len(auth_headers) > 1: - raise AuthError( - token_not_found_http_status, - "Too many Authorization headers.", - errcode=Codes.MISSING_TOKEN, - ) - parts = auth_headers[0].split(" ") - if parts[0] == "Bearer" and len(parts) == 2: - return parts[1] + Args: + request: The http request. + token_not_found_http_status(int): The HTTP status code to set in the + AuthError if the token isn't found. This is used in some of the + legacy APIs to change the status code to 403 from the default of + 401 since some of the old clients depended on auth errors returning + 403. + Returns: + str: The access_token + Raises: + AuthError: If there isn't an access_token in the request. + """ + + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + query_params = request.args.get(b"access_token") + if auth_headers: + # Try the get the access_token from a "Authorization: Bearer" + # header + if query_params is not None: + raise AuthError( + token_not_found_http_status, + "Mixing Authorization headers and access_token query parameters.", + errcode=Codes.MISSING_TOKEN, + ) + if len(auth_headers) > 1: + raise AuthError( + token_not_found_http_status, + "Too many Authorization headers.", + errcode=Codes.MISSING_TOKEN, + ) + parts = auth_headers[0].split(" ") + if parts[0] == "Bearer" and len(parts) == 2: + return parts[1] + else: + raise AuthError( + token_not_found_http_status, + "Invalid Authorization header.", + errcode=Codes.MISSING_TOKEN, + ) else: - raise AuthError( - token_not_found_http_status, - "Invalid Authorization header.", - errcode=Codes.MISSING_TOKEN, - ) - else: - # Try to get the access_token from the query params. - if not query_params: - raise AuthError( - token_not_found_http_status, - "Missing access token.", - errcode=Codes.MISSING_TOKEN - ) + # Try to get the access_token from the query params. + if not query_params: + raise AuthError( + token_not_found_http_status, + "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) - return query_params[0] + return query_params[0] diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 7c01b438cb..00b1b3066e 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -17,38 +17,20 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging -from synapse.api.auth import get_access_token_from_request from synapse.util.async import ObservableDeferred from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) - -def get_transaction_key(request): - """A helper function which returns a transaction key that can be used - with TransactionCache for idempotent requests. - - Idempotency is based on the returned key being the same for separate - requests to the same endpoint. The key is formed from the HTTP request - path and the access_token for the requesting user. - - Args: - request (twisted.web.http.Request): The incoming request. Must - contain an access_token. - Returns: - str: A transaction key - """ - token = get_access_token_from_request(request) - return request.path + "/" + token - - CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins class HttpTransactionCache(object): - def __init__(self, clock): - self.clock = clock + def __init__(self, hs): + self.hs = hs + self.auth = self.hs.get_auth() + self.clock = self.hs.get_clock() self.transactions = { # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) } @@ -56,6 +38,23 @@ class HttpTransactionCache(object): # for at *LEAST* 30 mins, and at *MOST* 60 mins. self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) + def _get_transaction_key(self, request): + """A helper function which returns a transaction key that can be used + with TransactionCache for idempotent requests. + + Idempotency is based on the returned key being the same for separate + requests to the same endpoint. The key is formed from the HTTP request + path and the access_token for the requesting user. + + Args: + request (twisted.web.http.Request): The incoming request. Must + contain an access_token. + Returns: + str: A transaction key + """ + token = self.auth.get_access_token_from_request(request) + return request.path + "/" + token + def fetch_or_execute_request(self, request, fn, *args, **kwargs): """A helper function for fetch_or_execute which extracts a transaction key from the given request. @@ -64,7 +63,7 @@ class HttpTransactionCache(object): fetch_or_execute """ return self.fetch_or_execute( - get_transaction_key(request), fn, *args, **kwargs + self._get_transaction_key(request), fn, *args, **kwargs ) def fetch_or_execute(self, txn_key, fn, *args, **kwargs): diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index dde02328c3..c77d7aba68 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -62,4 +62,4 @@ class ClientV1RestServlet(RestServlet): self.hs = hs self.builder_factory = hs.get_event_builder_factory() self.auth = hs.get_auth() - self.txns = HttpTransactionCache(hs.get_clock()) + self.txns = HttpTransactionCache(hs) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 05a8ecfcd8..430c692336 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -17,7 +17,6 @@ import logging from twisted.internet import defer -from synapse.api.auth import get_access_token_from_request from synapse.api.errors import AuthError from .base import ClientV1RestServlet, client_path_patterns @@ -51,7 +50,7 @@ class LogoutRestServlet(ClientV1RestServlet): if requester.device_id is None: # the acccess token wasn't associated with a device. # Just delete the access token - access_token = get_access_token_from_request(request) + access_token = self._auth.get_access_token_from_request(request) yield self._auth_handler.delete_access_token(access_token) else: yield self._device_handler.delete_device( diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 3ce5f8b726..9203e60615 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -23,7 +23,6 @@ from six import string_types from twisted.internet import defer import synapse.util.stringutils as stringutils -from synapse.api.auth import get_access_token_from_request from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import parse_json_object_from_request @@ -67,6 +66,7 @@ class RegisterRestServlet(ClientV1RestServlet): # TODO: persistent storage self.sessions = {} self.enable_registration = hs.config.enable_registration + self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.handlers = hs.get_handlers() @@ -310,7 +310,7 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_app_service(self, request, register_json, session): - as_token = get_access_token_from_request(request) + as_token = self.auth.get_access_token_from_request(request) if "user" not in register_json: raise SynapseError(400, "Expected 'user' key.") @@ -400,7 +400,7 @@ class CreateUserRestServlet(ClientV1RestServlet): def on_POST(self, request): user_json = parse_json_object_from_request(request) - access_token = get_access_token_from_request(request) + access_token = self.auth.get_access_token_from_request(request) app_service = self.store.get_app_service_by_token( access_token ) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 528c1f43f9..1263abb02c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -20,7 +20,6 @@ from six.moves import http_client from twisted.internet import defer -from synapse.api.auth import has_access_token from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( @@ -130,7 +129,7 @@ class PasswordRestServlet(RestServlet): # # In the second case, we require a password to confirm their identity. - if has_access_token(request): + if self.auth.has_access_token(request): requester = yield self.auth.get_user_by_req(request) params = yield self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request), diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 896650d5a5..d305afcaef 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -24,7 +24,6 @@ from twisted.internet import defer import synapse import synapse.types -from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError from synapse.http.servlet import ( @@ -224,7 +223,7 @@ class RegisterRestServlet(RestServlet): desired_username = body['username'] appservice = None - if has_access_token(request): + if self.auth.has_access_token(request): appservice = yield self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes and shared secret auth which @@ -242,7 +241,7 @@ class RegisterRestServlet(RestServlet): # because the IRC bridges rely on being able to register stupid # IDs. - access_token = get_access_token_from_request(request) + access_token = self.auth.get_access_token_from_request(request) if isinstance(desired_username, string_types): result = yield self._do_appservice_registration( diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index 90bdb1db15..a9e9a47a0b 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): super(SendToDeviceRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.txns = HttpTransactionCache(hs.get_clock()) + self.txns = HttpTransactionCache(hs) self.device_message_handler = hs.get_device_message_handler() def on_PUT(self, request, message_type, txn_id): diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index eee99ca2e0..34e68ae82f 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -14,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def setUp(self): self.clock = MockClock() - self.cache = HttpTransactionCache(self.clock) + self.hs = Mock() + self.hs.get_clock = Mock(return_value=self.clock) + self.hs.get_auth = Mock() + self.cache = HttpTransactionCache(self.hs) self.mock_http_response = (200, "GOOD JOB!") self.mock_key = "foo" -- cgit 1.4.1 From a2374b2c7ff7753ecc02906fcbe1b2fb20e0d0fb Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 14 Jul 2018 07:52:58 +1000 Subject: fix sytests --- synapse/rest/client/v1/room.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 5a867b4cf2..3d62447854 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -438,7 +438,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args filter_bytes = parse_string(request, "filter") if filter_bytes: - filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8") + filter_json = urlparse.unquote(filter_bytes).decode("UTF-8") event_filter = Filter(json.loads(filter_json)) else: event_filter = None -- cgit 1.4.1 From 8a4f05fefbeff7ff984c0c81a66891edc455c85b Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 14 Jul 2018 09:51:00 +1000 Subject: Fix develop because I broke it :( (#3535) --- changelog.d/3535.misc | 0 synapse/rest/client/v1/push_rule.py | 4 ++-- synapse/streams/config.py | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 changelog.d/3535.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/3535.misc b/changelog.d/3535.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 7cf6a99774..6e95d9bec2 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -77,11 +77,11 @@ class PushRuleRestServlet(ClientV1RestServlet): before = parse_string(request, "before") if before: - before = _namespaced_rule_id(spec, before[0]) + before = _namespaced_rule_id(spec, before) after = parse_string(request, "after") if after: - after = _namespaced_rule_id(spec, after[0]) + after = _namespaced_rule_id(spec, after) try: yield self.store.add_push_rule( diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 22b0b830f7..451e4fa441 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -78,6 +78,9 @@ class PaginationConfig(object): limit = parse_integer(request, "limit", default=default_limit) + if limit and limit < 0: + raise SynapseError(400, "Limit must be 0 or above") + try: return PaginationConfig(from_tok, to_tok, direction, limit) except Exception: -- cgit 1.4.1 From 4a27000548d2b5194256152305fe121f671f0c81 Mon Sep 17 00:00:00 2001 From: Krombel Date: Mon, 16 Jul 2018 13:46:49 +0200 Subject: check isort by travis --- .travis.yml | 3 +++ changelog.d/3540.misc | 1 + synapse/config/consent_config.py | 1 + synapse/config/logger.py | 1 + synapse/config/server_notices_config.py | 1 + synapse/http/client.py | 4 +++- synapse/http/site.py | 2 +- synapse/rest/__init__.py | 16 +++++++++++++--- synapse/rest/client/v1/admin.py | 4 ++-- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v2_alpha/devices.py | 2 +- synapse/rest/media/v1/identicon_resource.py | 4 ++-- synapse/storage/schema/delta/48/group_unique_indexes.py | 1 + tox.ini | 8 ++++++-- 14 files changed, 37 insertions(+), 13 deletions(-) create mode 100644 changelog.d/3540.misc (limited to 'synapse/rest/client') diff --git a/.travis.yml b/.travis.yml index a98d547978..e7fcd8b5d1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -23,6 +23,9 @@ matrix: - python: 3.6 env: TOX_ENV=py36 + - python: 3.6 + env: TOX_ENV=isort + - python: 3.6 env: TOX_ENV=check-newsfragment diff --git a/changelog.d/3540.misc b/changelog.d/3540.misc new file mode 100644 index 0000000000..99dcad8e46 --- /dev/null +++ b/changelog.d/3540.misc @@ -0,0 +1 @@ +check isort for each PR diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index e22c731aad..cec978125b 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -15,6 +15,7 @@ from ._base import Config + DEFAULT_CONFIG = """\ # User Consent configuration # diff --git a/synapse/config/logger.py b/synapse/config/logger.py index a87b11a1df..cc1051057c 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -29,6 +29,7 @@ from synapse.util.versionstring import get_version_string from ._base import Config + DEFAULT_LOG_CONFIG = Template(""" version: 1 diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index 3c39850ac6..91b08446b7 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -16,6 +16,7 @@ from synapse.types import UserID from ._base import Config + DEFAULT_CONFIG = """\ # Server Notices room configuration # diff --git a/synapse/http/client.py b/synapse/http/client.py index d6a0d75b2b..0622a69e86 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -26,9 +26,11 @@ from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, protocol, reactor, ssl, task from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, BrowserLikeRedirectAgent, ContentDecoderAgent from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer from twisted.web.client import ( + Agent, + BrowserLikeRedirectAgent, + ContentDecoderAgent, GzipDecoder, HTTPConnectionPool, PartialDownloadError, diff --git a/synapse/http/site.py b/synapse/http/site.py index 21e26f9c5e..41dd974cea 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -20,7 +20,7 @@ from twisted.web.server import Request, Site from synapse.http import redact_uri from synapse.http.request_metrics import RequestMetrics -from synapse.util.logcontext import LoggingContext, ContextResourceUsage +from synapse.util.logcontext import ContextResourceUsage, LoggingContext logger = logging.getLogger(__name__) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 75c2a4ec8e..0d69d37f7e 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -15,11 +15,21 @@ from synapse.http.server import JsonResource from synapse.rest.client import versions -from synapse.rest.client.v1 import admin, directory, events, initial_sync from synapse.rest.client.v1 import login as v1_login -from synapse.rest.client.v1 import logout, presence, profile, push_rule, pusher from synapse.rest.client.v1 import register as v1_register -from synapse.rest.client.v1 import room, voip +from synapse.rest.client.v1 import ( + admin, + directory, + events, + initial_sync, + logout, + presence, + profile, + push_rule, + pusher, + room, + voip, +) from synapse.rest.client.v2_alpha import ( account, account_data, diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 01c3f2eb04..2dc50e582b 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -24,9 +24,9 @@ from synapse.api.constants import Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import ( assert_params_in_dict, - parse_json_object_from_request, parse_integer, - parse_string + parse_json_object_from_request, + parse_string, ) from synapse.types import UserID, create_requester diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 00a1a99feb..fd5f85b53e 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -15,8 +15,8 @@ from twisted.internet import defer -from synapse.streams.config import PaginationConfig from synapse.http.servlet import parse_boolean +from synapse.streams.config import PaginationConfig from .base import ClientV1RestServlet, client_path_patterns diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index aded2409be..9b75bb1377 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -19,9 +19,9 @@ from twisted.internet import defer from synapse.api import errors from synapse.http.servlet import ( + RestServlet, assert_params_in_dict, parse_json_object_from_request, - RestServlet ) from ._base import client_v2_patterns, interactive_auth_handler diff --git a/synapse/rest/media/v1/identicon_resource.py b/synapse/rest/media/v1/identicon_resource.py index b3217eff53..bdbd8d50dd 100644 --- a/synapse/rest/media/v1/identicon_resource.py +++ b/synapse/rest/media/v1/identicon_resource.py @@ -14,10 +14,10 @@ from pydenticon import Generator -from synapse.http.servlet import parse_integer - from twisted.web.resource import Resource +from synapse.http.servlet import parse_integer + FOREGROUND = [ "rgb(45,79,255)", "rgb(254,180,44)", diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py index 2233af87d7..ac4ac66cd4 100644 --- a/synapse/storage/schema/delta/48/group_unique_indexes.py +++ b/synapse/storage/schema/delta/48/group_unique_indexes.py @@ -15,6 +15,7 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.prepare_database import get_statements + FIX_INDEXES = """ -- rebuild indexes as uniques DROP INDEX groups_invites_g_idx; diff --git a/tox.ini b/tox.ini index 61a20a10cb..c7491690c5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = packaging, py27, py36, pep8 +envlist = packaging, py27, py36, pep8, isort [testenv] deps = @@ -103,10 +103,14 @@ deps = flake8 commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}" +[testenv:isort] +skip_install = True +deps = isort +commands = /bin/sh -c "isort -c -sp setup.cfg -rc synapse tests" [testenv:check-newsfragment] skip_install = True deps = towncrier>=18.6.0rc1 commands = python -m towncrier.check --compare-with=origin/develop -basepython = python3.6 \ No newline at end of file +basepython = python3.6 -- cgit 1.4.1 From bc006b3c9d2a9982bc834ff5d1ec1768c85f907a Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 17 Jul 2018 20:43:18 +1000 Subject: Refactor REST API tests to use explicit reactors (#3351) --- changelog.d/3351.misc | 0 synapse/http/site.py | 5 +- synapse/rest/client/v2_alpha/register.py | 2 +- tests/rest/client/v1/test_register.py | 68 +- tests/rest/client/v1/test_rooms.py | 1156 +++++++++++---------------- tests/rest/client/v1/utils.py | 115 ++- tests/rest/client/v2_alpha/__init__.py | 61 -- tests/rest/client/v2_alpha/test_filter.py | 129 +-- tests/rest/client/v2_alpha/test_register.py | 211 ++--- tests/rest/client/v2_alpha/test_sync.py | 83 ++ tests/server.py | 10 + tests/test_server.py | 22 +- tests/unittest.py | 11 + 13 files changed, 939 insertions(+), 934 deletions(-) create mode 100644 changelog.d/3351.misc create mode 100644 tests/rest/client/v2_alpha/test_sync.py (limited to 'synapse/rest/client') diff --git a/changelog.d/3351.misc b/changelog.d/3351.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/synapse/http/site.py b/synapse/http/site.py index 41dd974cea..5fd30a4c2c 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -42,9 +42,10 @@ class SynapseRequest(Request): which is handling the request, and returns a context manager. """ - def __init__(self, site, *args, **kw): - Request.__init__(self, *args, **kw) + def __init__(self, site, channel, *args, **kw): + Request.__init__(self, channel, *args, **kw) self.site = site + self._channel = channel self.authenticated_entity = None self.start_time = 0 diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1918897f68..d6cf915d86 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -643,7 +643,7 @@ class RegisterRestServlet(RestServlet): @defer.inlineCallbacks def _do_guest_registration(self, params): if not self.hs.config.allow_guest_access: - defer.returnValue((403, "Guest access is disabled")) + raise SynapseError(403, "Guest access is disabled") user_id, _ = yield self.registration_handler.register( generate_token=False, make_guest=True diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index f596acb85f..f15fb36213 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -17,26 +17,22 @@ import json from mock import Mock -from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactorClock -from synapse.rest.client.v1.register import CreateUserRestServlet +from synapse.http.server import JsonResource +from synapse.rest.client.v1.register import register_servlets +from synapse.util import Clock from tests import unittest -from tests.utils import mock_getRawHeaders +from tests.server import make_request, setup_test_homeserver class CreateUserServletTestCase(unittest.TestCase): + """ + Tests for CreateUserRestServlet. + """ def setUp(self): - # do the dance to hook up request data to self.request_data - self.request_data = "" - self.request = Mock( - content=Mock(read=Mock(side_effect=lambda: self.request_data)), - path='/_matrix/client/api/v1/createUser' - ) - self.request.args = {} - self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() - self.registration_handler = Mock() self.appservice = Mock(sender="@as:test") @@ -44,39 +40,49 @@ class CreateUserServletTestCase(unittest.TestCase): get_app_service_by_token=Mock(return_value=self.appservice) ) - # do the dance to hook things up to the hs global - handlers = Mock( - registration_handler=self.registration_handler, + handlers = Mock(registration_handler=self.registration_handler) + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) + + self.hs = self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock ) - self.hs = Mock() - self.hs.hostname = "superbig~testing~thing.com" self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_handlers = Mock(return_value=handlers) - self.servlet = CreateUserRestServlet(self.hs) - @defer.inlineCallbacks def test_POST_createuser_with_valid_user(self): + + res = JsonResource(self.hs) + register_servlets(self.hs, res) + + request_data = json.dumps( + { + "localpart": "someone", + "displayname": "someone interesting", + "duration_seconds": 200, + } + ) + + url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service' + user_id = "@someone:interesting" token = "my token" - self.request.args = { - "access_token": "i_am_an_app_service" - } - self.request_data = json.dumps({ - "localpart": "someone", - "displayname": "someone interesting", - "duration_seconds": 200 - }) self.registration_handler.get_or_create_user = Mock( return_value=(user_id, token) ) - (code, result) = yield self.servlet.on_POST(self.request) - self.assertEquals(code, 200) + request, channel = make_request(b"POST", url, request_data) + request.render(res) + + # Advance the clock because it waits + self.clock.advance(1) + + self.assertEquals(channel.result["code"], b"200") det_data = { "user_id": user_id, "access_token": token, - "home_server": self.hs.hostname + "home_server": self.hs.hostname, } - self.assertDictContainsSubset(det_data, result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 895dffa095..6b5764095e 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -25,949 +25,772 @@ from twisted.internet import defer import synapse.rest.client.v1.room from synapse.api.constants import Membership +from synapse.http.server import JsonResource from synapse.types import UserID +from synapse.util import Clock -from ....utils import MockHttpResource, setup_test_homeserver -from .utils import RestTestCase +from tests import unittest +from tests.server import ( + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) -PATH_PREFIX = "/_matrix/client/api/v1" +from .utils import RestHelper +PATH_PREFIX = b"/_matrix/client/api/v1" -class RoomPermissionsTestCase(RestTestCase): - """ Tests room permissions. """ - user_id = "@sid1:red" - rmcreator_id = "@notme:red" - @defer.inlineCallbacks +class RoomBase(unittest.TestCase): + rmcreator_id = None + def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - hs = yield setup_test_homeserver( + self.clock = ThreadedMemoryReactorClock() + self.hs_clock = Clock(self.clock) + + self.hs = setup_test_homeserver( "red", http_client=None, + clock=self.hs_clock, + reactor=self.clock, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) - self.ratelimiter = hs.get_ratelimiter() + self.ratelimiter = self.hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() + self.hs.get_federation_handler = Mock(return_value=Mock()) def get_user_by_access_token(token=None, allow_guest=False): return { - "user": UserID.from_string(self.auth_user_id), + "user": UserID.from_string(self.helper.auth_user_id), "token_id": 1, "is_guest": False, } - hs.get_auth().get_user_by_access_token = get_user_by_access_token + + def get_user_by_req(request, allow_guest=False, rights="access"): + return synapse.types.create_requester( + UserID.from_string(self.helper.auth_user_id), 1, False, None + ) + + self.hs.get_auth().get_user_by_req = get_user_by_req + self.hs.get_auth().get_user_by_access_token = get_user_by_access_token + self.hs.get_auth().get_access_token_from_request = Mock(return_value=b"1234") def _insert_client_ip(*args, **kwargs): return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - self.auth_user_id = self.rmcreator_id + self.hs.get_datastore().insert_client_ip = _insert_client_ip - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + self.resource = JsonResource(self.hs) + synapse.rest.client.v1.room.register_servlets(self.hs, self.resource) + self.helper = RestHelper(self.hs, self.resource, self.user_id) - self.auth = hs.get_auth() - # create some rooms under the name rmcreator_id - self.uncreated_rmid = "!aa:test" +class RoomPermissionsTestCase(RoomBase): + """ Tests room permissions. """ - self.created_rmid = yield self.create_room_as(self.rmcreator_id, - is_public=False) + user_id = b"@sid1:red" + rmcreator_id = b"@notme:red" + + def setUp(self): - self.created_public_rmid = yield self.create_room_as(self.rmcreator_id, - is_public=True) + super(RoomPermissionsTestCase, self).setUp() + + self.helper.auth_user_id = self.rmcreator_id + # create some rooms under the name rmcreator_id + self.uncreated_rmid = "!aa:test" + self.created_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=False + ) + self.created_public_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=True + ) # send a message in one of the rooms self.created_rmid_msg_path = ( - "/rooms/%s/send/m.room.message/a1" % (self.created_rmid) - ) - (code, response) = yield self.mock_resource.trigger( - "PUT", + "rooms/%s/send/m.room.message/a1" % (self.created_rmid) + ).encode('ascii') + request, channel = make_request( + b"PUT", self.created_rmid_msg_path, - '{"msgtype":"m.text","body":"test msg"}' + b'{"msgtype":"m.text","body":"test msg"}', ) - self.assertEquals(200, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(channel.result["code"], b"200", channel.result) # set topic for public room - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - '{"topic":"Public Room Topic"}' + request, channel = make_request( + b"PUT", + ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'), + b'{"topic":"Public Room Topic"}', ) - self.assertEquals(200, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(channel.result["code"], b"200", channel.result) # auth as user_id now - self.auth_user_id = self.user_id - - def tearDown(self): - pass + self.helper.auth_user_id = self.user_id - @defer.inlineCallbacks def test_send_message(self): - msg_content = '{"msgtype":"m.text","body":"hello"}' - send_msg_path = ( - "/rooms/%s/send/m.room.message/mid1" % (self.created_rmid,) - ) + msg_content = b'{"msgtype":"m.text","body":"hello"}' + + seq = iter(range(100)) + + def send_msg_path(): + return b"/rooms/%s/send/m.room.message/mid%s" % ( + self.created_rmid, + str(next(seq)).encode('ascii'), + ) # send message in uncreated room, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), - msg_content + request, channel = make_request( + b"PUT", + b"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), + msg_content, ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content - ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room and invited, expect 403 - yield self.invite( - room=self.created_rmid, - src=self.rmcreator_id, - targ=self.user_id - ) - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content + self.helper.invite( + room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room and joined, expect 200 - yield self.join(room=self.created_rmid, user=self.user_id) - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content - ) - self.assertEquals(200, code, msg=str(response)) + self.helper.join(room=self.created_rmid, user=self.user_id) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room and left, expect 403 - yield self.leave(room=self.created_rmid, user=self.user_id) - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content - ) - self.assertEquals(403, code, msg=str(response)) + self.helper.leave(room=self.created_rmid, user=self.user_id) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_topic_perms(self): - topic_content = '{"topic":"My Topic Name"}' - topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid + topic_content = b'{"topic":"My Topic Name"}' + topic_path = b"/rooms/%s/state/m.room.topic" % self.created_rmid # set/get topic in uncreated room, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, - topic_content + request, channel = make_request( + b"PUT", b"/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEquals(403, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/state/m.room.topic" % self.uncreated_rmid + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = make_request( + b"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(403, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 - yield self.invite( + self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 - yield self.join(room=self.created_rmid, user=self.user_id) + self.helper.join(room=self.created_rmid, user=self.user_id) # Only room ops can set topic by default - self.auth_user_id = self.rmcreator_id - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(200, code, msg=str(response)) - self.auth_user_id = self.user_id + self.helper.auth_user_id = self.rmcreator_id + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.helper.auth_user_id = self.user_id - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(200, code, msg=str(response)) - self.assert_dict(json.loads(topic_content), response) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assert_dict(json.loads(topic_content), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 - yield self.leave(room=self.created_rmid, user=self.user_id) - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(403, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(200, code, msg=str(response)) + self.helper.leave(room=self.created_rmid, user=self.user_id) + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/state/m.room.topic" % self.created_public_rmid + request, channel = make_request( + b"GET", b"/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - topic_content + request, channel = make_request( + b"PUT", + b"/rooms/%s/state/m.room.topic" % self.created_public_rmid, + topic_content, ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def _test_get_membership(self, room=None, members=[], expect_code=None): for member in members: - path = "/rooms/%s/state/m.room.member/%s" % (room, member) - (code, response) = yield self.mock_resource.trigger_get(path) - self.assertEquals(expect_code, code) + path = b"/rooms/%s/state/m.room.member/%s" % (room, member) + request, channel = make_request(b"GET", path) + render(request, self.resource, self.clock) + self.assertEquals(expect_code, int(channel.result["code"])) - @defer.inlineCallbacks def test_membership_basic_room_perms(self): # === room does not exist === room = self.uncreated_rmid # get membership of self, get membership of other, uncreated room # expect all 403s - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) # trying to invite people to this room should 403 - yield self.invite(room=room, src=self.user_id, targ=self.rmcreator_id, - expect_code=403) + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) # set [invite/join/left] of self, set [invite/join/left] of other, # expect all 404s because room doesn't exist on any server for usr in [self.user_id, self.rmcreator_id]: - yield self.join(room=room, user=usr, expect_code=404) - yield self.leave(room=room, user=usr, expect_code=404) + self.helper.join(room=room, user=usr, expect_code=404) + self.helper.leave(room=room, user=usr, expect_code=404) - @defer.inlineCallbacks def test_membership_private_room_perms(self): room = self.created_rmid # get membership of self, get membership of other, private room + invite # expect all 403s - yield self.invite(room=room, src=self.rmcreator_id, - targ=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) # get membership of self, get membership of other, private room + joined # expect all 200s - yield self.join(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) # get membership of self, get membership of other, private room + left # expect all 200s - yield self.leave(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) - @defer.inlineCallbacks def test_membership_public_room_perms(self): room = self.created_public_rmid # get membership of self, get membership of other, public room + invite # expect 403 - yield self.invite(room=room, src=self.rmcreator_id, - targ=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) # get membership of self, get membership of other, public room + joined # expect all 200s - yield self.join(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) # get membership of self, get membership of other, public room + left # expect 200. - yield self.leave(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) - @defer.inlineCallbacks def test_invited_permissions(self): room = self.created_rmid - yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) # set [invite/join/left] of other user, expect 403s - yield self.invite(room=room, src=self.user_id, targ=self.rmcreator_id, - expect_code=403) - yield self.change_membership(room=room, src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.JOIN, - expect_code=403) - yield self.change_membership(room=room, src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.LEAVE, - expect_code=403) - - @defer.inlineCallbacks + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.JOIN, + expect_code=403, + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.LEAVE, + expect_code=403, + ) + def test_joined_permissions(self): room = self.created_rmid - yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - yield self.join(room=room, user=self.user_id) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) # set invited of self, expect 403 - yield self.invite(room=room, src=self.user_id, targ=self.user_id, - expect_code=403) + self.helper.invite( + room=room, src=self.user_id, targ=self.user_id, expect_code=403 + ) # set joined of self, expect 200 (NOOP) - yield self.join(room=room, user=self.user_id) + self.helper.join(room=room, user=self.user_id) other = "@burgundy:red" # set invited of other, expect 200 - yield self.invite(room=room, src=self.user_id, targ=other, - expect_code=200) + self.helper.invite(room=room, src=self.user_id, targ=other, expect_code=200) # set joined of other, expect 403 - yield self.change_membership(room=room, src=self.user_id, - targ=other, - membership=Membership.JOIN, - expect_code=403) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.JOIN, + expect_code=403, + ) # set left of other, expect 403 - yield self.change_membership(room=room, src=self.user_id, - targ=other, - membership=Membership.LEAVE, - expect_code=403) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.LEAVE, + expect_code=403, + ) # set left of self, expect 200 - yield self.leave(room=room, user=self.user_id) + self.helper.leave(room=room, user=self.user_id) - @defer.inlineCallbacks def test_leave_permissions(self): room = self.created_rmid - yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - yield self.join(room=room, user=self.user_id) - yield self.leave(room=room, user=self.user_id) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) + self.helper.leave(room=room, user=self.user_id) # set [invite/join/left] of self, set [invite/join/left] of other, # expect all 403s for usr in [self.user_id, self.rmcreator_id]: - yield self.change_membership( + self.helper.change_membership( room=room, src=self.user_id, targ=usr, membership=Membership.INVITE, - expect_code=403 + expect_code=403, ) - yield self.change_membership( + self.helper.change_membership( room=room, src=self.user_id, targ=usr, membership=Membership.JOIN, - expect_code=403 + expect_code=403, ) # It is always valid to LEAVE if you've already left (currently.) - yield self.change_membership( + self.helper.change_membership( room=room, src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403 + expect_code=403, ) -class RoomsMemberListTestCase(RestTestCase): +class RoomsMemberListTestCase(RoomBase): """ Tests /rooms/$room_id/members/list REST events.""" - user_id = "@sid1:red" - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - self.auth_user_id = self.user_id - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + user_id = b"@sid1:red" - def tearDown(self): - pass - - @defer.inlineCallbacks def test_get_member_list(self): - room_id = yield self.create_room_as(self.user_id) - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/members" % room_id - ) - self.assertEquals(200, code, msg=str(response)) + room_id = self.helper.create_room_as(self.user_id) + request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_get_member_list_no_room(self): - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/roomdoesnotexist/members" - ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"GET", b"/rooms/roomdoesnotexist/members") + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_get_member_list_no_permission(self): - room_id = yield self.create_room_as("@some_other_guy:red") - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/members" % room_id - ) - self.assertEquals(403, code, msg=str(response)) + room_id = self.helper.create_room_as(b"@some_other_guy:red") + request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_get_member_list_mixed_memberships(self): - room_creator = "@some_other_guy:red" - room_id = yield self.create_room_as(room_creator) - room_path = "/rooms/%s/members" % room_id - yield self.invite(room=room_id, src=room_creator, - targ=self.user_id) + room_creator = b"@some_other_guy:red" + room_id = self.helper.create_room_as(room_creator) + room_path = b"/rooms/%s/members" % room_id + self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. - (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"GET", room_path) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - yield self.join(room=room_id, user=self.user_id) + self.helper.join(room=room_id, user=self.user_id) # can see list now joined - (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"GET", room_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - yield self.leave(room=room_id, user=self.user_id) + self.helper.leave(room=room_id, user=self.user_id) # can see old list once left - (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"GET", room_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) -class RoomsCreateTestCase(RestTestCase): +class RoomsCreateTestCase(RoomBase): """ Tests /rooms and /rooms/$room_id REST events. """ - user_id = "@sid1:red" - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + user_id = b"@sid1:red" - @defer.inlineCallbacks def test_post_room_no_keys(self): # POST with no config keys, expect new room id - (code, response) = yield self.mock_resource.trigger("POST", - "/createRoom", - "{}") - self.assertEquals(200, code, response) - self.assertTrue("room_id" in response) + request, channel = make_request(b"POST", b"/createRoom", b"{}") + + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), channel.result) + self.assertTrue("room_id" in channel.json_body) - @defer.inlineCallbacks def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"visibility":"private"}') - self.assertEquals(200, code) - self.assertTrue("room_id" in response) - - @defer.inlineCallbacks + request, channel = make_request( + b"POST", b"/createRoom", b'{"visibility":"private"}' + ) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("room_id" in channel.json_body) + def test_post_room_custom_key(self): # POST with custom config keys, expect new room id - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"custom":"stuff"}') - self.assertEquals(200, code) - self.assertTrue("room_id" in response) - - @defer.inlineCallbacks + request, channel = make_request(b"POST", b"/createRoom", b'{"custom":"stuff"}') + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("room_id" in channel.json_body) + def test_post_room_known_and_unknown_keys(self): # POST with custom + known config keys, expect new room id - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"visibility":"private","custom":"things"}') - self.assertEquals(200, code) - self.assertTrue("room_id" in response) - - @defer.inlineCallbacks + request, channel = make_request( + b"POST", b"/createRoom", b'{"visibility":"private","custom":"things"}' + ) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("room_id" in channel.json_body) + def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"visibili') - self.assertEquals(400, code) + request, channel = make_request(b"POST", b"/createRoom", b'{"visibili') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"])) - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '["hello"]') - self.assertEquals(400, code) + request, channel = make_request(b"POST", b"/createRoom", b'["hello"]') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"])) -class RoomTopicTestCase(RestTestCase): +class RoomTopicTestCase(RoomBase): """ Tests /rooms/$room_id/topic REST events. """ - user_id = "@sid1:red" - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id + user_id = b"@sid1:red" - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip + def setUp(self): - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + super(RoomTopicTestCase, self).setUp() # create the room - self.room_id = yield self.create_room_as(self.user_id) - self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) - - def tearDown(self): - pass + self.room_id = self.helper.create_room_as(self.user_id) + self.path = b"/rooms/%s/state/m.room.topic" % (self.room_id,) - @defer.inlineCallbacks def test_invalid_puts(self): # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '{}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '{}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '{"_name":"bob"}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '{"_name":"bob"}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '{"nao' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '{"nao') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = make_request( + b"PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' ) - self.assertEquals(400, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, 'text only' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, 'text only') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, content - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, content) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_rooms_topic(self): # nothing should be there - (code, response) = yield self.mock_resource.trigger_get(self.path) - self.assertEquals(404, code, msg=str(response)) + request, channel = make_request(b"GET", self.path) + render(request, self.resource, self.clock) + self.assertEquals(404, int(channel.result["code"]), msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, content - ) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # valid get - (code, response) = yield self.mock_resource.trigger_get(self.path) - self.assertEquals(200, code, msg=str(response)) - self.assert_dict(json.loads(content), response) + request, channel = make_request(b"GET", self.path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) - @defer.inlineCallbacks def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, content - ) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # valid get - (code, response) = yield self.mock_resource.trigger_get(self.path) - self.assertEquals(200, code, msg=str(response)) - self.assert_dict(json.loads(content), response) + request, channel = make_request(b"GET", self.path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) -class RoomMemberStateTestCase(RestTestCase): +class RoomMemberStateTestCase(RoomBase): """ Tests /rooms/$room_id/members/$user_id/state REST events. """ - user_id = "@sid1:red" - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() + user_id = b"@sid1:red" - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + def setUp(self): - self.room_id = yield self.create_room_as(self.user_id) + super(RoomMemberStateTestCase, self).setUp() + self.room_id = self.helper.create_room_as(self.user_id) def tearDown(self): pass - @defer.inlineCallbacks def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger("PUT", path, '{}') - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"_name":"bob"}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"_name":"bob"}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"nao' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"nao') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = make_request( + b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]' ) - self.assertEquals(400, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, 'text only' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, 'text only') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) # valid keys, wrong types - content = ('{"membership":["%s","%s","%s"]}' % ( - Membership.INVITE, Membership.JOIN, Membership.LEAVE - )) - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(400, code, msg=str(response)) + content = '{"membership":["%s","%s","%s"]}' % ( + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + ) + request, channel = make_request(b"PUT", path, content.encode('ascii')) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), self.user_id + urlparse.quote(self.room_id), + self.user_id, ) # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content.encode('ascii')) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger("GET", path, None) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"GET", path, None) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - expected_response = { - "membership": Membership.JOIN, - } - self.assertEquals(expected_response, response) + expected_response = {"membership": Membership.JOIN} + self.assertEquals(expected_response, channel.json_body) - @defer.inlineCallbacks def test_rooms_members_other(self): self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), self.other_id + urlparse.quote(self.room_id), + self.other_id, ) # valid invite message content = '{"membership":"%s"}' % Membership.INVITE - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger("GET", path, None) - self.assertEquals(200, code, msg=str(response)) - self.assertEquals(json.loads(content), response) + request, channel = make_request(b"GET", path, None) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) - @defer.inlineCallbacks def test_rooms_members_other_custom_keys(self): self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), self.other_id + urlparse.quote(self.room_id), + self.other_id, ) # valid invite message with custom key - content = ('{"membership":"%s","invite_text":"%s"}' % ( - Membership.INVITE, "Join us!" - )) - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + content = '{"membership":"%s","invite_text":"%s"}' % ( + Membership.INVITE, + "Join us!", + ) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger("GET", path, None) - self.assertEquals(200, code, msg=str(response)) - self.assertEquals(json.loads(content), response) + request, channel = make_request(b"GET", path, None) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) -class RoomMessagesTestCase(RestTestCase): +class RoomMessagesTestCase(RoomBase): """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """ + user_id = "@sid1:red" - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() + super(RoomMessagesTestCase, self).setUp() - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip + self.room_id = self.helper.create_room_as(self.user_id) - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - - self.room_id = yield self.create_room_as(self.user_id) - - def tearDown(self): - pass - - @defer.inlineCallbacks def test_invalid_puts(self): - path = "/rooms/%s/send/m.room.message/mid1" % ( - urlparse.quote(self.room_id)) + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"_name":"bob"}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"_name":"bob"}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"nao' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"nao') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = make_request( + b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' ) - self.assertEquals(400, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, 'text only' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, 'text only') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_rooms_messages_sent(self): - path = "/rooms/%s/send/m.room.message/mid1" % ( - urlparse.quote(self.room_id)) + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = '{"body":"test","msgtype":{"type":"a"}}' - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) # custom message types content = '{"body":"test","msgtype":"test.custom.text"}' - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) - -# (code, response) = yield self.mock_resource.trigger("GET", path, None) -# self.assertEquals(200, code, msg=str(response)) -# self.assert_dict(json.loads(content), response) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # m.text message type - path = "/rooms/%s/send/m.room.message/mid2" % ( - urlparse.quote(self.room_id)) + path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = '{"body":"test2","msgtype":"m.text"}' - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) -class RoomInitialSyncTestCase(RestTestCase): +class RoomInitialSyncTestCase(RoomBase): """ Tests /rooms/$room_id/initialSync. """ + user_id = "@sid1:red" - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + super(RoomInitialSyncTestCase, self).setUp() # create the room - self.room_id = yield self.create_room_as(self.user_id) + self.room_id = self.helper.create_room_as(self.user_id) - @defer.inlineCallbacks def test_initial_sync(self): - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/initialSync" % self.room_id - ) - self.assertEquals(200, code) + request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) - self.assertEquals(self.room_id, response["room_id"]) - self.assertEquals("join", response["membership"]) + self.assertEquals(self.room_id, channel.json_body["room_id"]) + self.assertEquals("join", channel.json_body["membership"]) # Room state is easier to assert on if we unpack it into a dict state = {} - for event in response["state"]: + for event in channel.json_body["state"]: if "state_key" not in event: continue t = event["type"] @@ -977,75 +800,48 @@ class RoomInitialSyncTestCase(RestTestCase): self.assertTrue("m.room.create" in state) - self.assertTrue("messages" in response) - self.assertTrue("chunk" in response["messages"]) - self.assertTrue("end" in response["messages"]) + self.assertTrue("messages" in channel.json_body) + self.assertTrue("chunk" in channel.json_body["messages"]) + self.assertTrue("end" in channel.json_body["messages"]) - self.assertTrue("presence" in response) + self.assertTrue("presence" in channel.json_body) presence_by_user = { - e["content"]["user_id"]: e for e in response["presence"] + e["content"]["user_id"]: e for e in channel.json_body["presence"] } self.assertTrue(self.user_id in presence_by_user) self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) -class RoomMessageListTestCase(RestTestCase): +class RoomMessageListTestCase(RoomBase): """ Tests /rooms/$room_id/messages REST events. """ + user_id = "@sid1:red" - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token + super(RoomMessageListTestCase, self).setUp() + self.room_id = self.helper.create_room_as(self.user_id) - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - - self.room_id = yield self.create_room_as(self.user_id) - - @defer.inlineCallbacks def test_topo_token_is_accepted(self): token = "t1-0_0_0_0_0_0_0_0_0" - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/messages?access_token=x&from=%s" % - (self.room_id, token)) - self.assertEquals(200, code) - self.assertTrue("start" in response) - self.assertEquals(token, response['start']) - self.assertTrue("chunk" in response) - self.assertTrue("end" in response) - - @defer.inlineCallbacks + request, channel = make_request( + b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + ) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("start" in channel.json_body) + self.assertEquals(token, channel.json_body['start']) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("end" in channel.json_body) + def test_stream_token_is_accepted_for_fwd_pagianation(self): token = "s0_0_0_0_0_0_0_0_0" - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/messages?access_token=x&from=%s" % - (self.room_id, token)) - self.assertEquals(200, code) - self.assertTrue("start" in response) - self.assertEquals(token, response['start']) - self.assertTrue("chunk" in response) - self.assertTrue("end" in response) + request, channel = make_request( + b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + ) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("start" in channel.json_body) + self.assertEquals(token, channel.json_body['start']) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("end" in channel.json_body) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 54d7ba380d..41de8e0762 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -16,13 +16,14 @@ import json import time -# twisted imports +import attr + from twisted.internet import defer from synapse.api.constants import Membership -# trial imports from tests import unittest +from tests.server import make_request, wait_until_result class RestTestCase(unittest.TestCase): @@ -133,3 +134,113 @@ class RestTestCase(unittest.TestCase): for key in required: self.assertEquals(required[key], actual[key], msg="%s mismatch. %s" % (key, actual)) + + +@attr.s +class RestHelper(object): + """Contains extra helper functions to quickly and clearly perform a given + REST action, which isn't the focus of the test. + """ + + hs = attr.ib() + resource = attr.ib() + auth_user_id = attr.ib() + + def create_room_as(self, room_creator, is_public=True, tok=None): + temp_id = self.auth_user_id + self.auth_user_id = room_creator + path = b"/_matrix/client/r0/createRoom" + content = {} + if not is_public: + content["visibility"] = "private" + if tok: + path = path + b"?access_token=%s" % tok.encode('ascii') + + request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8')) + request.render(self.resource) + wait_until_result(self.hs.get_reactor(), channel) + + assert channel.result["code"] == b"200", channel.result + self.auth_user_id = temp_id + return channel.json_body["room_id"] + + def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=src, + targ=targ, + tok=tok, + membership=Membership.INVITE, + expect_code=expect_code, + ) + + def join(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.JOIN, + expect_code=expect_code, + ) + + def leave(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.LEAVE, + expect_code=expect_code, + ) + + def change_membership(self, room, src, targ, membership, tok=None, expect_code=200): + temp_id = self.auth_user_id + self.auth_user_id = src + + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) + if tok: + path = path + "?access_token=%s" % tok + + data = {"membership": membership} + + request, channel = make_request( + b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8') + ) + + request.render(self.resource) + wait_until_result(self.hs.get_reactor(), channel) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + self.auth_user_id = temp_id + + @defer.inlineCallbacks + def register(self, user_id): + (code, response) = yield self.mock_resource.trigger( + "POST", + "/_matrix/client/r0/register", + json.dumps( + {"user": user_id, "password": "test", "type": "m.login.password"} + ), + ) + self.assertEquals(200, code) + defer.returnValue(response) + + @defer.inlineCallbacks + def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + if body is None: + body = "body_text_here" + + path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) + content = '{"msgtype":"m.text","body":"%s"}' % body + if tok: + path = path + "?access_token=%s" % tok + + (code, response) = yield self.mock_resource.trigger("PUT", path, content) + self.assertEquals(expect_code, code, msg=str(response)) diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index f18a8a6027..e69de29bb2 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -1,61 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from mock import Mock - -from twisted.internet import defer - -from synapse.types import UserID - -from tests import unittest - -from ....utils import MockHttpResource, setup_test_homeserver - -PATH_PREFIX = "/_matrix/client/v2_alpha" - - -class V2AlphaRestTestCase(unittest.TestCase): - # Consumer must define - # USER_ID = - # TO_REGISTER = [] - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - - hs = yield setup_test_homeserver( - datastore=self.make_datastore_mock(), - http_client=None, - resource_for_client=self.mock_resource, - resource_for_federation=self.mock_resource, - ) - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.USER_ID), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - for r in self.TO_REGISTER: - r.register_servlets(hs, self.mock_resource) - - def make_datastore_mock(self): - store = Mock(spec=[ - "insert_client_ip", - ]) - store.get_app_service_by_token = Mock(return_value=None) - return store diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index bb0b2f94ea..5ea9cc825f 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -13,35 +13,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import synapse.types from synapse.api.errors import Codes +from synapse.http.server import JsonResource from synapse.rest.client.v2_alpha import filter from synapse.types import UserID +from synapse.util import Clock from tests import unittest - -from ....utils import MockHttpResource, setup_test_homeserver +from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock +from tests.server import make_request, setup_test_homeserver, wait_until_result PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.TestCase): - USER_ID = "@apple:test" + USER_ID = b"@apple:test" EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} - EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}' + EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' TO_REGISTER = [filter] - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) - self.hs = yield setup_test_homeserver( - http_client=None, - resource_for_client=self.mock_resource, - resource_for_federation=self.mock_resource, + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock ) self.auth = self.hs.get_auth() @@ -55,82 +53,103 @@ class FilterTestCase(unittest.TestCase): def get_user_by_req(request, allow_guest=False, rights="access"): return synapse.types.create_requester( - UserID.from_string(self.USER_ID), 1, False, None) + UserID.from_string(self.USER_ID), 1, False, None + ) self.auth.get_user_by_access_token = get_user_by_access_token self.auth.get_user_by_req = get_user_by_req self.store = self.hs.get_datastore() self.filtering = self.hs.get_filtering() + self.resource = JsonResource(self.hs) for r in self.TO_REGISTER: - r.register_servlets(self.hs, self.mock_resource) + r.register_servlets(self.hs, self.resource) - @defer.inlineCallbacks def test_add_filter(self): - (code, response) = yield self.mock_resource.trigger( - "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON - ) - self.assertEquals(200, code) - self.assertEquals({"filter_id": "0"}, response) - filter = yield self.store.get_user_filter( - user_localpart='apple', - filter_id=0, + request, channel = make_request( + b"POST", + b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + self.EXAMPLE_FILTER_JSON, ) - self.assertEquals(filter, self.EXAMPLE_FILTER) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.json_body, {"filter_id": "0"}) + filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) + self.clock.advance(0) + self.assertEquals(filter.result, self.EXAMPLE_FILTER) - @defer.inlineCallbacks def test_add_filter_for_other_user(self): - (code, response) = yield self.mock_resource.trigger( - "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON + request, channel = make_request( + b"POST", + b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"), + self.EXAMPLE_FILTER_JSON, ) - self.assertEquals(403, code) - self.assertEquals(response['errcode'], Codes.FORBIDDEN) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - @defer.inlineCallbacks def test_add_filter_non_local_user(self): _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False - (code, response) = yield self.mock_resource.trigger( - "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON + request, channel = make_request( + b"POST", + b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + self.EXAMPLE_FILTER_JSON, ) + request.render(self.resource) + wait_until_result(self.clock, channel) + self.hs.is_mine = _is_mine - self.assertEquals(403, code) - self.assertEquals(response['errcode'], Codes.FORBIDDEN) + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - @defer.inlineCallbacks def test_get_filter(self): - filter_id = yield self.filtering.add_user_filter( - user_localpart='apple', - user_filter=self.EXAMPLE_FILTER + filter_id = self.filtering.add_user_filter( + user_localpart="apple", user_filter=self.EXAMPLE_FILTER ) - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/%s" % (self.USER_ID, filter_id) + self.clock.advance(1) + filter_id = filter_id.result + request, channel = make_request( + b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) ) - self.assertEquals(200, code) - self.assertEquals(self.EXAMPLE_FILTER, response) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"200") + self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) - @defer.inlineCallbacks def test_get_filter_non_existant(self): - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/12382148321" % (self.USER_ID) + request, channel = make_request( + b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) ) - self.assertEquals(400, code) - self.assertEquals(response['errcode'], Codes.NOT_FOUND) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"400") + self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode # in errors.py - @defer.inlineCallbacks def test_get_filter_invalid_id(self): - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/foobar" % (self.USER_ID) + request, channel = make_request( + b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) ) - self.assertEquals(400, code) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"400") # No ID also returns an invalid_id error - @defer.inlineCallbacks def test_get_filter_no_id(self): - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/" % (self.USER_ID) + request, channel = make_request( + b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) ) - self.assertEquals(400, code) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 9b57a56070..e004d8fc73 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -2,165 +2,192 @@ import json from mock import Mock -from twisted.internet import defer from twisted.python import failure +from twisted.test.proto_helpers import MemoryReactorClock -from synapse.api.errors import InteractiveAuthIncompleteError, SynapseError -from synapse.rest.client.v2_alpha.register import RegisterRestServlet +from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.http.server import JsonResource +from synapse.rest.client.v2_alpha.register import register_servlets +from synapse.util import Clock from tests import unittest -from tests.utils import mock_getRawHeaders +from tests.server import make_request, setup_test_homeserver, wait_until_result class RegisterRestServletTestCase(unittest.TestCase): - def setUp(self): - # do the dance to hook up request data to self.request_data - self.request_data = "" - self.request = Mock( - content=Mock(read=Mock(side_effect=lambda: self.request_data)), - path='/_matrix/api/v2_alpha/register' - ) - self.request.args = {} - self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) + self.url = b"/_matrix/client/r0/register" self.appservice = None - self.auth = Mock(get_appservice_by_req=Mock( - side_effect=lambda x: self.appservice) + self.auth = Mock( + get_appservice_by_req=Mock(side_effect=lambda x: self.appservice) ) self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None)) self.auth_handler = Mock( check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), - get_session_data=Mock(return_value=None) + get_session_data=Mock(return_value=None), ) self.registration_handler = Mock() self.identity_handler = Mock() self.login_handler = Mock() self.device_handler = Mock() + self.device_handler.check_device_registered = Mock(return_value="FAKE") + + self.datastore = Mock(return_value=Mock()) + self.datastore.get_current_state_deltas = Mock(return_value=[]) # do the dance to hook it up to the hs global self.handlers = Mock( registration_handler=self.registration_handler, identity_handler=self.identity_handler, - login_handler=self.login_handler + login_handler=self.login_handler, + ) + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock ) - self.hs = Mock() - self.hs.hostname = "superbig~testing~thing.com" self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_device_handler = Mock(return_value=self.device_handler) + self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] self.hs.config.auto_join_rooms = [] - # init the thing we're testing - self.servlet = RegisterRestServlet(self.hs) + self.resource = JsonResource(self.hs) + register_servlets(self.hs, self.resource) - @defer.inlineCallbacks def test_POST_appservice_registration_valid(self): user_id = "@kermit:muppet" token = "kermits_access_token" - self.request.args = { - "access_token": "i_am_an_app_service" - } - self.request_data = json.dumps({ - "username": "kermit" - }) - self.appservice = { - "id": "1234" - } - self.registration_handler.appservice_register = Mock( - return_value=user_id - ) - self.auth_handler.get_access_token_for_user_id = Mock( - return_value=token + self.appservice = {"id": "1234"} + self.registration_handler.appservice_register = Mock(return_value=user_id) + self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) + request_data = json.dumps({"username": "kermit"}) + + request, channel = make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) + request.render(self.resource) + wait_until_result(self.clock, channel) - (code, result) = yield self.servlet.on_POST(self.request) - self.assertEquals(code, 200) + self.assertEquals(channel.result["code"], b"200", channel.result) det_data = { "user_id": user_id, "access_token": token, - "home_server": self.hs.hostname + "home_server": self.hs.hostname, } - self.assertDictContainsSubset(det_data, result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) - @defer.inlineCallbacks def test_POST_appservice_registration_invalid(self): - self.request.args = { - "access_token": "i_am_an_app_service" - } - - self.request_data = json.dumps({ - "username": "kermit" - }) self.appservice = None # no application service exists - result = yield self.servlet.on_POST(self.request) - self.assertEquals(result, (401, None)) + request_data = json.dumps({"username": "kermit"}) + request, channel = make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"401", channel.result) def test_POST_bad_password(self): - self.request_data = json.dumps({ - "username": "kermit", - "password": 666 - }) - d = self.servlet.on_POST(self.request) - return self.assertFailure(d, SynapseError) + request_data = json.dumps({"username": "kermit", "password": 666}) + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], "Invalid password" + ) def test_POST_bad_username(self): - self.request_data = json.dumps({ - "username": 777, - "password": "monkey" - }) - d = self.servlet.on_POST(self.request) - return self.assertFailure(d, SynapseError) - - @defer.inlineCallbacks + request_data = json.dumps({"username": 777, "password": "monkey"}) + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], "Invalid username" + ) + def test_POST_user_valid(self): user_id = "@kermit:muppet" token = "kermits_access_token" device_id = "frogfone" - self.request_data = json.dumps({ - "username": "kermit", - "password": "monkey", - "device_id": device_id, - }) + request_data = json.dumps( + {"username": "kermit", "password": "monkey", "device_id": device_id} + ) self.registration_handler.check_username = Mock(return_value=True) - self.auth_result = (None, { - "username": "kermit", - "password": "monkey" - }, None) + self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.get_access_token_for_user_id = Mock( - return_value=token - ) - self.device_handler.check_device_registered = \ - Mock(return_value=device_id) + self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) + self.device_handler.check_device_registered = Mock(return_value=device_id) + + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) - (code, result) = yield self.servlet.on_POST(self.request) - self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, "home_server": self.hs.hostname, "device_id": device_id, } - self.assertDictContainsSubset(det_data, result) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) self.auth_handler.get_login_tuple_for_user_id( - user_id, device_id=device_id, initial_device_display_name=None) + user_id, device_id=device_id, initial_device_display_name=None + ) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False - self.request_data = json.dumps({ - "username": "kermit", - "password": "monkey" - }) + request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.registration_handler.check_username = Mock(return_value=True) - self.auth_result = (None, { - "username": "kermit", - "password": "monkey" - }, None) + self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.registration_handler.register = Mock(return_value=("@user:id", "t")) - d = self.servlet.on_POST(self.request) - return self.assertFailure(d, SynapseError) + + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], + "Registration has been disabled", + ) + + def test_POST_guest_registration(self): + user_id = "a@b" + self.hs.config.macaroon_secret_key = "test" + self.hs.config.allow_guest_access = True + self.registration_handler.register = Mock(return_value=(user_id, None)) + + request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") + request.render(self.resource) + wait_until_result(self.clock, channel) + + det_data = { + "user_id": user_id, + "home_server": self.hs.hostname, + "device_id": "guest_device", + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) + + def test_POST_disabled_guest_registration(self): + self.hs.config.allow_guest_access = False + + request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], "Guest access is disabled" + ) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py new file mode 100644 index 0000000000..704cf97a40 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse.types +from synapse.http.server import JsonResource +from synapse.rest.client.v2_alpha import sync +from synapse.types import UserID +from synapse.util import Clock + +from tests import unittest +from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock +from tests.server import make_request, setup_test_homeserver, wait_until_result + +PATH_PREFIX = "/_matrix/client/v2_alpha" + + +class FilterTestCase(unittest.TestCase): + + USER_ID = b"@apple:test" + TO_REGISTER = [sync] + + def setUp(self): + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) + + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock + ) + + self.auth = self.hs.get_auth() + + def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.USER_ID), + "token_id": 1, + "is_guest": False, + } + + def get_user_by_req(request, allow_guest=False, rights="access"): + return synapse.types.create_requester( + UserID.from_string(self.USER_ID), 1, False, None + ) + + self.auth.get_user_by_access_token = get_user_by_access_token + self.auth.get_user_by_req = get_user_by_req + + self.store = self.hs.get_datastore() + self.filtering = self.hs.get_filtering() + self.resource = JsonResource(self.hs) + + for r in self.TO_REGISTER: + r.register_servlets(self.hs, self.resource) + + def test_sync_argless(self): + request, channel = make_request(b"GET", b"/_matrix/client/r0/sync") + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"200") + self.assertTrue( + set( + [ + "next_batch", + "rooms", + "presence", + "account_data", + "to_device", + "device_lists", + ] + ).issubset(set(channel.json_body.keys())) + ) diff --git a/tests/server.py b/tests/server.py index e93f5a7f94..c611dd6059 100644 --- a/tests/server.py +++ b/tests/server.py @@ -80,6 +80,11 @@ def make_request(method, path, content=b""): content, and return the Request and the Channel underneath. """ + # Decorate it to be the full path + if not path.startswith(b"/_matrix"): + path = b"/_matrix/client/r0/" + path + path = path.replace("//", "/") + if isinstance(content, text_type): content = content.encode('utf8') @@ -110,6 +115,11 @@ def wait_until_result(clock, channel, timeout=100): clock.advance(0.1) +def render(request, resource, clock): + request.render(resource) + wait_until_result(clock, request._channel) + + class ThreadedMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports callFromThread. diff --git a/tests/test_server.py b/tests/test_server.py index 4192013f6d..7e063c0290 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -33,9 +33,11 @@ class JsonResourceTests(unittest.TestCase): return (200, kwargs) res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/foo/(?P[^/]*)$")], _callback) + res.register_paths( + "GET", [re.compile("^/_matrix/foo/(?P[^/]*)$")], _callback + ) - request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83") + request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83") request.render(res) self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) @@ -51,9 +53,9 @@ class JsonResourceTests(unittest.TestCase): raise Exception("boo") res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/foo$")], _callback) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/foo") + request, channel = make_request(b"GET", b"/_matrix/foo") request.render(res) self.assertEqual(channel.result["code"], b'500') @@ -74,9 +76,9 @@ class JsonResourceTests(unittest.TestCase): return d res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/foo$")], _callback) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/foo") + request, channel = make_request(b"GET", b"/_matrix/foo") request.render(res) # No error has been raised yet @@ -96,9 +98,9 @@ class JsonResourceTests(unittest.TestCase): raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN) res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/foo$")], _callback) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/foo") + request, channel = make_request(b"GET", b"/_matrix/foo") request.render(res) self.assertEqual(channel.result["code"], b'403') @@ -118,9 +120,9 @@ class JsonResourceTests(unittest.TestCase): self.fail("shouldn't ever get here") res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/foo$")], _callback) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/foobar") + request, channel = make_request(b"GET", b"/_matrix/foobar") request.render(res) self.assertEqual(channel.result["code"], b'400') diff --git a/tests/unittest.py b/tests/unittest.py index b25f2db5d5..b15b06726b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -109,6 +109,17 @@ class TestCase(unittest.TestCase): except AssertionError as e: raise (type(e))(e.message + " for '.%s'" % key) + def assert_dict(self, required, actual): + """Does a partial assert of a dict. + + Args: + required (dict): The keys and value which MUST be in 'actual'. + actual (dict): The test result. Extra keys will not be checked. + """ + for key in required: + self.assertEquals(required[key], actual[key], + msg="%s mismatch. %s" % (key, actual)) + def DEBUG(target): """A decorator to set the .loglevel attribute to logging.DEBUG. -- cgit 1.4.1 From a97c845271f9a89ebdb7186d4c9d04c099bd1beb Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 19 Jul 2018 20:03:33 +1000 Subject: Move v1-only APIs into their own module & isolate deprecated ones (#3460) --- changelog.d/3460.misc | 0 setup.cfg | 1 + synapse/http/client.py | 6 +- synapse/rest/__init__.py | 43 ++- synapse/rest/client/v1/register.py | 436 ----------------------------- synapse/rest/client/v1/room.py | 5 +- synapse/rest/client/v1_only/__init__.py | 3 + synapse/rest/client/v1_only/base.py | 39 +++ synapse/rest/client/v1_only/register.py | 437 ++++++++++++++++++++++++++++++ tests/rest/client/v1/test_events.py | 90 +----- tests/rest/client/v1/test_register.py | 5 +- tests/rest/client/v1/test_rooms.py | 2 +- tests/rest/client/v2_alpha/test_filter.py | 8 +- tests/rest/client/v2_alpha/test_sync.py | 8 +- 14 files changed, 549 insertions(+), 534 deletions(-) create mode 100644 changelog.d/3460.misc delete mode 100644 synapse/rest/client/v1/register.py create mode 100644 synapse/rest/client/v1_only/__init__.py create mode 100644 synapse/rest/client/v1_only/base.py create mode 100644 synapse/rest/client/v1_only/register.py (limited to 'synapse/rest/client') diff --git a/changelog.d/3460.misc b/changelog.d/3460.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/setup.cfg b/setup.cfg index 3f84283a38..c2620be6c5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,3 +36,4 @@ known_compat = mock,six known_twisted=twisted,OpenSSL multi_line_output=3 include_trailing_comma=true +combine_as_imports=true diff --git a/synapse/http/client.py b/synapse/http/client.py index d6a0d75b2b..25b6307884 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -26,9 +26,11 @@ from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, protocol, reactor, ssl, task from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, BrowserLikeRedirectAgent, ContentDecoderAgent -from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer from twisted.web.client import ( + Agent, + BrowserLikeRedirectAgent, + ContentDecoderAgent, + FileBodyProducer as TwistedFileBodyProducer, GzipDecoder, HTTPConnectionPool, PartialDownloadError, diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 75c2a4ec8e..3418f06fd6 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +14,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six import PY3 + from synapse.http.server import JsonResource from synapse.rest.client import versions -from synapse.rest.client.v1 import admin, directory, events, initial_sync -from synapse.rest.client.v1 import login as v1_login -from synapse.rest.client.v1 import logout, presence, profile, push_rule, pusher -from synapse.rest.client.v1 import register as v1_register -from synapse.rest.client.v1 import room, voip +from synapse.rest.client.v1 import ( + admin, + directory, + events, + initial_sync, + login as v1_login, + logout, + presence, + profile, + push_rule, + pusher, + room, + voip, +) from synapse.rest.client.v2_alpha import ( account, account_data, @@ -42,6 +54,11 @@ from synapse.rest.client.v2_alpha import ( user_directory, ) +if not PY3: + from synapse.rest.client.v1_only import ( + register as v1_register, + ) + class ClientRestResource(JsonResource): """A resource for version 1 of the matrix client API.""" @@ -54,14 +71,22 @@ class ClientRestResource(JsonResource): def register_servlets(client_resource, hs): versions.register_servlets(client_resource) - # "v1" - room.register_servlets(hs, client_resource) + if not PY3: + # "v1" (Python 2 only) + v1_register.register_servlets(hs, client_resource) + + # Deprecated in r0 + initial_sync.register_servlets(hs, client_resource) + room.register_deprecated_servlets(hs, client_resource) + + # Partially deprecated in r0 events.register_servlets(hs, client_resource) - v1_register.register_servlets(hs, client_resource) + + # "v1" + "r0" + room.register_servlets(hs, client_resource) v1_login.register_servlets(hs, client_resource) profile.register_servlets(hs, client_resource) presence.register_servlets(hs, client_resource) - initial_sync.register_servlets(hs, client_resource) directory.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py deleted file mode 100644 index 25a143af8d..0000000000 --- a/synapse/rest/client/v1/register.py +++ /dev/null @@ -1,436 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains REST servlets to do with registration: /register""" -import hmac -import logging -from hashlib import sha1 - -from twisted.internet import defer - -import synapse.util.stringutils as stringutils -from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request -from synapse.types import create_requester - -from .base import ClientV1RestServlet, client_path_patterns - -logger = logging.getLogger(__name__) - - -# We ought to be using hmac.compare_digest() but on older pythons it doesn't -# exist. It's a _really minor_ security flaw to use plain string comparison -# because the timing attack is so obscured by all the other code here it's -# unlikely to make much difference -if hasattr(hmac, "compare_digest"): - compare_digest = hmac.compare_digest -else: - def compare_digest(a, b): - return a == b - - -class RegisterRestServlet(ClientV1RestServlet): - """Handles registration with the home server. - - This servlet is in control of the registration flow; the registration - handler doesn't have a concept of multi-stages or sessions. - """ - - PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super(RegisterRestServlet, self).__init__(hs) - # sessions are stored as: - # self.sessions = { - # "session_id" : { __session_dict__ } - # } - # TODO: persistent storage - self.sessions = {} - self.enable_registration = hs.config.enable_registration - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - self.handlers = hs.get_handlers() - - def on_GET(self, request): - - require_email = 'email' in self.hs.config.registrations_require_3pid - require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid - - flows = [] - if self.hs.config.enable_registration_captcha: - # only support the email-only flow if we don't require MSISDN 3PIDs - if not require_msisdn: - flows.extend([ - { - "type": LoginType.RECAPTCHA, - "stages": [ - LoginType.RECAPTCHA, - LoginType.EMAIL_IDENTITY, - LoginType.PASSWORD - ] - }, - ]) - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - flows.extend([ - { - "type": LoginType.RECAPTCHA, - "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD] - } - ]) - else: - # only support the email-only flow if we don't require MSISDN 3PIDs - if require_email or not require_msisdn: - flows.extend([ - { - "type": LoginType.EMAIL_IDENTITY, - "stages": [ - LoginType.EMAIL_IDENTITY, LoginType.PASSWORD - ] - } - ]) - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - flows.extend([ - { - "type": LoginType.PASSWORD - } - ]) - return (200, {"flows": flows}) - - @defer.inlineCallbacks - def on_POST(self, request): - register_json = parse_json_object_from_request(request) - - session = (register_json["session"] - if "session" in register_json else None) - login_type = None - assert_params_in_dict(register_json, ["type"]) - - try: - login_type = register_json["type"] - - is_application_server = login_type == LoginType.APPLICATION_SERVICE - is_using_shared_secret = login_type == LoginType.SHARED_SECRET - - can_register = ( - self.enable_registration - or is_application_server - or is_using_shared_secret - ) - if not can_register: - raise SynapseError(403, "Registration has been disabled") - - stages = { - LoginType.RECAPTCHA: self._do_recaptcha, - LoginType.PASSWORD: self._do_password, - LoginType.EMAIL_IDENTITY: self._do_email_identity, - LoginType.APPLICATION_SERVICE: self._do_app_service, - LoginType.SHARED_SECRET: self._do_shared_secret, - } - - session_info = self._get_session_info(request, session) - logger.debug("%s : session info %s request info %s", - login_type, session_info, register_json) - response = yield stages[login_type]( - request, - register_json, - session_info - ) - - if "access_token" not in response: - # isn't a final response - response["session"] = session_info["id"] - - defer.returnValue((200, response)) - except KeyError as e: - logger.exception(e) - raise SynapseError(400, "Missing JSON keys for login type %s." % ( - login_type, - )) - - def on_OPTIONS(self, request): - return (200, {}) - - def _get_session_info(self, request, session_id): - if not session_id: - # create a new session - while session_id is None or session_id in self.sessions: - session_id = stringutils.random_string(24) - self.sessions[session_id] = { - "id": session_id, - LoginType.EMAIL_IDENTITY: False, - LoginType.RECAPTCHA: False - } - - return self.sessions[session_id] - - def _save_session(self, session): - # TODO: Persistent storage - logger.debug("Saving session %s", session) - self.sessions[session["id"]] = session - - def _remove_session(self, session): - logger.debug("Removing session %s", session) - self.sessions.pop(session["id"]) - - @defer.inlineCallbacks - def _do_recaptcha(self, request, register_json, session): - if not self.hs.config.enable_registration_captcha: - raise SynapseError(400, "Captcha not required.") - - yield self._check_recaptcha(request, register_json, session) - - session[LoginType.RECAPTCHA] = True # mark captcha as done - self._save_session(session) - defer.returnValue({ - "next": [LoginType.PASSWORD, LoginType.EMAIL_IDENTITY] - }) - - @defer.inlineCallbacks - def _check_recaptcha(self, request, register_json, session): - if ("captcha_bypass_hmac" in register_json and - self.hs.config.captcha_bypass_secret): - if "user" not in register_json: - raise SynapseError(400, "Captcha bypass needs 'user'") - - want = hmac.new( - key=self.hs.config.captcha_bypass_secret, - msg=register_json["user"], - digestmod=sha1, - ).hexdigest() - - # str() because otherwise hmac complains that 'unicode' does not - # have the buffer interface - got = str(register_json["captcha_bypass_hmac"]) - - if compare_digest(want, got): - session["user"] = register_json["user"] - defer.returnValue(None) - else: - raise SynapseError( - 400, "Captcha bypass HMAC incorrect", - errcode=Codes.CAPTCHA_NEEDED - ) - - challenge = None - user_response = None - try: - challenge = register_json["challenge"] - user_response = register_json["response"] - except KeyError: - raise SynapseError(400, "Captcha response is required", - errcode=Codes.CAPTCHA_NEEDED) - - ip_addr = self.hs.get_ip_from_request(request) - - handler = self.handlers.registration_handler - yield handler.check_recaptcha( - ip_addr, - self.hs.config.recaptcha_private_key, - challenge, - user_response - ) - - @defer.inlineCallbacks - def _do_email_identity(self, request, register_json, session): - if (self.hs.config.enable_registration_captcha and - not session[LoginType.RECAPTCHA]): - raise SynapseError(400, "Captcha is required.") - - threepidCreds = register_json['threepidCreds'] - handler = self.handlers.registration_handler - logger.debug("Registering email. threepidcreds: %s" % (threepidCreds)) - yield handler.register_email(threepidCreds) - session["threepidCreds"] = threepidCreds # store creds for next stage - session[LoginType.EMAIL_IDENTITY] = True # mark email as done - self._save_session(session) - defer.returnValue({ - "next": LoginType.PASSWORD - }) - - @defer.inlineCallbacks - def _do_password(self, request, register_json, session): - if (self.hs.config.enable_registration_captcha and - not session[LoginType.RECAPTCHA]): - # captcha should've been done by this stage! - raise SynapseError(400, "Captcha is required.") - - if ("user" in session and "user" in register_json and - session["user"] != register_json["user"]): - raise SynapseError( - 400, "Cannot change user ID during registration" - ) - - password = register_json["password"].encode("utf-8") - desired_user_id = ( - register_json["user"].encode("utf-8") - if "user" in register_json else None - ) - - handler = self.handlers.registration_handler - (user_id, token) = yield handler.register( - localpart=desired_user_id, - password=password - ) - - if session[LoginType.EMAIL_IDENTITY]: - logger.debug("Binding emails %s to %s" % ( - session["threepidCreds"], user_id) - ) - yield handler.bind_emails(user_id, session["threepidCreds"]) - - result = { - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - } - self._remove_session(session) - defer.returnValue(result) - - @defer.inlineCallbacks - def _do_app_service(self, request, register_json, session): - as_token = self.auth.get_access_token_from_request(request) - - assert_params_in_dict(register_json, ["user"]) - user_localpart = register_json["user"].encode("utf-8") - - handler = self.handlers.registration_handler - user_id = yield handler.appservice_register( - user_localpart, as_token - ) - token = yield self.auth_handler.issue_access_token(user_id) - self._remove_session(session) - defer.returnValue({ - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - }) - - @defer.inlineCallbacks - def _do_shared_secret(self, request, register_json, session): - assert_params_in_dict(register_json, ["mac", "user", "password"]) - - if not self.hs.config.registration_shared_secret: - raise SynapseError(400, "Shared secret registration is not enabled") - - user = register_json["user"].encode("utf-8") - password = register_json["password"].encode("utf-8") - admin = register_json.get("admin", None) - - # Its important to check as we use null bytes as HMAC field separators - if b"\x00" in user: - raise SynapseError(400, "Invalid user") - if b"\x00" in password: - raise SynapseError(400, "Invalid password") - - # str() because otherwise hmac complains that 'unicode' does not - # have the buffer interface - got_mac = str(register_json["mac"]) - - want_mac = hmac.new( - key=self.hs.config.registration_shared_secret.encode(), - digestmod=sha1, - ) - want_mac.update(user) - want_mac.update(b"\x00") - want_mac.update(password) - want_mac.update(b"\x00") - want_mac.update(b"admin" if admin else b"notadmin") - want_mac = want_mac.hexdigest() - - if compare_digest(want_mac, got_mac): - handler = self.handlers.registration_handler - user_id, token = yield handler.register( - localpart=user.lower(), - password=password, - admin=bool(admin), - ) - self._remove_session(session) - defer.returnValue({ - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - }) - else: - raise SynapseError( - 403, "HMAC incorrect", - ) - - -class CreateUserRestServlet(ClientV1RestServlet): - """Handles user creation via a server-to-server interface - """ - - PATTERNS = client_path_patterns("/createUser$", releases=()) - - def __init__(self, hs): - super(CreateUserRestServlet, self).__init__(hs) - self.store = hs.get_datastore() - self.handlers = hs.get_handlers() - - @defer.inlineCallbacks - def on_POST(self, request): - user_json = parse_json_object_from_request(request) - - access_token = self.auth.get_access_token_from_request(request) - app_service = self.store.get_app_service_by_token( - access_token - ) - if not app_service: - raise SynapseError(403, "Invalid application service token.") - - requester = create_requester(app_service.sender) - - logger.debug("creating user: %s", user_json) - response = yield self._do_create(requester, user_json) - - defer.returnValue((200, response)) - - def on_OPTIONS(self, request): - return 403, {} - - @defer.inlineCallbacks - def _do_create(self, requester, user_json): - assert_params_in_dict(user_json, ["localpart", "displayname"]) - - localpart = user_json["localpart"].encode("utf-8") - displayname = user_json["displayname"].encode("utf-8") - password_hash = user_json["password_hash"].encode("utf-8") \ - if user_json.get("password_hash") else None - - handler = self.handlers.registration_handler - user_id, token = yield handler.get_or_create_user( - requester=requester, - localpart=localpart, - displayname=displayname, - password_hash=password_hash - ) - - defer.returnValue({ - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - }) - - -def register_servlets(hs, http_server): - RegisterRestServlet(hs).register(http_server) - CreateUserRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 3d62447854..b9512a2b61 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -832,10 +832,13 @@ def register_servlets(hs, http_server): RoomSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) - RoomInitialSyncRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) JoinedRoomsRestServlet(hs).register(http_server) RoomEventServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) + + +def register_deprecated_servlets(hs, http_server): + RoomInitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1_only/__init__.py b/synapse/rest/client/v1_only/__init__.py new file mode 100644 index 0000000000..936f902ace --- /dev/null +++ b/synapse/rest/client/v1_only/__init__.py @@ -0,0 +1,3 @@ +""" +REST APIs that are only used in v1 (the legacy API). +""" diff --git a/synapse/rest/client/v1_only/base.py b/synapse/rest/client/v1_only/base.py new file mode 100644 index 0000000000..9d4db7437c --- /dev/null +++ b/synapse/rest/client/v1_only/base.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains base REST classes for constructing client v1 servlets. +""" + +import re + +from synapse.api.urls import CLIENT_PREFIX + + +def v1_only_client_path_patterns(path_regex, include_in_unstable=True): + """Creates a regex compiled client path with the correct client path + prefix. + + Args: + path_regex (str): The regex string to match. This should NOT have a ^ + as this will be prefixed. + Returns: + list of SRE_Pattern + """ + patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)] + if include_in_unstable: + unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) + return patterns diff --git a/synapse/rest/client/v1_only/register.py b/synapse/rest/client/v1_only/register.py new file mode 100644 index 0000000000..3439c3c6d4 --- /dev/null +++ b/synapse/rest/client/v1_only/register.py @@ -0,0 +1,437 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains REST servlets to do with registration: /register""" +import hmac +import logging +from hashlib import sha1 + +from twisted.internet import defer + +import synapse.util.stringutils as stringutils +from synapse.api.constants import LoginType +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request +from synapse.rest.client.v1.base import ClientV1RestServlet +from synapse.types import create_requester + +from .base import v1_only_client_path_patterns + +logger = logging.getLogger(__name__) + + +# We ought to be using hmac.compare_digest() but on older pythons it doesn't +# exist. It's a _really minor_ security flaw to use plain string comparison +# because the timing attack is so obscured by all the other code here it's +# unlikely to make much difference +if hasattr(hmac, "compare_digest"): + compare_digest = hmac.compare_digest +else: + def compare_digest(a, b): + return a == b + + +class RegisterRestServlet(ClientV1RestServlet): + """Handles registration with the home server. + + This servlet is in control of the registration flow; the registration + handler doesn't have a concept of multi-stages or sessions. + """ + + PATTERNS = v1_only_client_path_patterns("/register$", include_in_unstable=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RegisterRestServlet, self).__init__(hs) + # sessions are stored as: + # self.sessions = { + # "session_id" : { __session_dict__ } + # } + # TODO: persistent storage + self.sessions = {} + self.enable_registration = hs.config.enable_registration + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() + + def on_GET(self, request): + + require_email = 'email' in self.hs.config.registrations_require_3pid + require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + + flows = [] + if self.hs.config.enable_registration_captcha: + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: + flows.extend([ + { + "type": LoginType.RECAPTCHA, + "stages": [ + LoginType.RECAPTCHA, + LoginType.EMAIL_IDENTITY, + LoginType.PASSWORD + ] + }, + ]) + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([ + { + "type": LoginType.RECAPTCHA, + "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD] + } + ]) + else: + # only support the email-only flow if we don't require MSISDN 3PIDs + if require_email or not require_msisdn: + flows.extend([ + { + "type": LoginType.EMAIL_IDENTITY, + "stages": [ + LoginType.EMAIL_IDENTITY, LoginType.PASSWORD + ] + } + ]) + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([ + { + "type": LoginType.PASSWORD + } + ]) + return (200, {"flows": flows}) + + @defer.inlineCallbacks + def on_POST(self, request): + register_json = parse_json_object_from_request(request) + + session = (register_json["session"] + if "session" in register_json else None) + login_type = None + assert_params_in_dict(register_json, ["type"]) + + try: + login_type = register_json["type"] + + is_application_server = login_type == LoginType.APPLICATION_SERVICE + is_using_shared_secret = login_type == LoginType.SHARED_SECRET + + can_register = ( + self.enable_registration + or is_application_server + or is_using_shared_secret + ) + if not can_register: + raise SynapseError(403, "Registration has been disabled") + + stages = { + LoginType.RECAPTCHA: self._do_recaptcha, + LoginType.PASSWORD: self._do_password, + LoginType.EMAIL_IDENTITY: self._do_email_identity, + LoginType.APPLICATION_SERVICE: self._do_app_service, + LoginType.SHARED_SECRET: self._do_shared_secret, + } + + session_info = self._get_session_info(request, session) + logger.debug("%s : session info %s request info %s", + login_type, session_info, register_json) + response = yield stages[login_type]( + request, + register_json, + session_info + ) + + if "access_token" not in response: + # isn't a final response + response["session"] = session_info["id"] + + defer.returnValue((200, response)) + except KeyError as e: + logger.exception(e) + raise SynapseError(400, "Missing JSON keys for login type %s." % ( + login_type, + )) + + def on_OPTIONS(self, request): + return (200, {}) + + def _get_session_info(self, request, session_id): + if not session_id: + # create a new session + while session_id is None or session_id in self.sessions: + session_id = stringutils.random_string(24) + self.sessions[session_id] = { + "id": session_id, + LoginType.EMAIL_IDENTITY: False, + LoginType.RECAPTCHA: False + } + + return self.sessions[session_id] + + def _save_session(self, session): + # TODO: Persistent storage + logger.debug("Saving session %s", session) + self.sessions[session["id"]] = session + + def _remove_session(self, session): + logger.debug("Removing session %s", session) + self.sessions.pop(session["id"]) + + @defer.inlineCallbacks + def _do_recaptcha(self, request, register_json, session): + if not self.hs.config.enable_registration_captcha: + raise SynapseError(400, "Captcha not required.") + + yield self._check_recaptcha(request, register_json, session) + + session[LoginType.RECAPTCHA] = True # mark captcha as done + self._save_session(session) + defer.returnValue({ + "next": [LoginType.PASSWORD, LoginType.EMAIL_IDENTITY] + }) + + @defer.inlineCallbacks + def _check_recaptcha(self, request, register_json, session): + if ("captcha_bypass_hmac" in register_json and + self.hs.config.captcha_bypass_secret): + if "user" not in register_json: + raise SynapseError(400, "Captcha bypass needs 'user'") + + want = hmac.new( + key=self.hs.config.captcha_bypass_secret, + msg=register_json["user"], + digestmod=sha1, + ).hexdigest() + + # str() because otherwise hmac complains that 'unicode' does not + # have the buffer interface + got = str(register_json["captcha_bypass_hmac"]) + + if compare_digest(want, got): + session["user"] = register_json["user"] + defer.returnValue(None) + else: + raise SynapseError( + 400, "Captcha bypass HMAC incorrect", + errcode=Codes.CAPTCHA_NEEDED + ) + + challenge = None + user_response = None + try: + challenge = register_json["challenge"] + user_response = register_json["response"] + except KeyError: + raise SynapseError(400, "Captcha response is required", + errcode=Codes.CAPTCHA_NEEDED) + + ip_addr = self.hs.get_ip_from_request(request) + + handler = self.handlers.registration_handler + yield handler.check_recaptcha( + ip_addr, + self.hs.config.recaptcha_private_key, + challenge, + user_response + ) + + @defer.inlineCallbacks + def _do_email_identity(self, request, register_json, session): + if (self.hs.config.enable_registration_captcha and + not session[LoginType.RECAPTCHA]): + raise SynapseError(400, "Captcha is required.") + + threepidCreds = register_json['threepidCreds'] + handler = self.handlers.registration_handler + logger.debug("Registering email. threepidcreds: %s" % (threepidCreds)) + yield handler.register_email(threepidCreds) + session["threepidCreds"] = threepidCreds # store creds for next stage + session[LoginType.EMAIL_IDENTITY] = True # mark email as done + self._save_session(session) + defer.returnValue({ + "next": LoginType.PASSWORD + }) + + @defer.inlineCallbacks + def _do_password(self, request, register_json, session): + if (self.hs.config.enable_registration_captcha and + not session[LoginType.RECAPTCHA]): + # captcha should've been done by this stage! + raise SynapseError(400, "Captcha is required.") + + if ("user" in session and "user" in register_json and + session["user"] != register_json["user"]): + raise SynapseError( + 400, "Cannot change user ID during registration" + ) + + password = register_json["password"].encode("utf-8") + desired_user_id = ( + register_json["user"].encode("utf-8") + if "user" in register_json else None + ) + + handler = self.handlers.registration_handler + (user_id, token) = yield handler.register( + localpart=desired_user_id, + password=password + ) + + if session[LoginType.EMAIL_IDENTITY]: + logger.debug("Binding emails %s to %s" % ( + session["threepidCreds"], user_id) + ) + yield handler.bind_emails(user_id, session["threepidCreds"]) + + result = { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + } + self._remove_session(session) + defer.returnValue(result) + + @defer.inlineCallbacks + def _do_app_service(self, request, register_json, session): + as_token = self.auth.get_access_token_from_request(request) + + assert_params_in_dict(register_json, ["user"]) + user_localpart = register_json["user"].encode("utf-8") + + handler = self.handlers.registration_handler + user_id = yield handler.appservice_register( + user_localpart, as_token + ) + token = yield self.auth_handler.issue_access_token(user_id) + self._remove_session(session) + defer.returnValue({ + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + }) + + @defer.inlineCallbacks + def _do_shared_secret(self, request, register_json, session): + assert_params_in_dict(register_json, ["mac", "user", "password"]) + + if not self.hs.config.registration_shared_secret: + raise SynapseError(400, "Shared secret registration is not enabled") + + user = register_json["user"].encode("utf-8") + password = register_json["password"].encode("utf-8") + admin = register_json.get("admin", None) + + # Its important to check as we use null bytes as HMAC field separators + if b"\x00" in user: + raise SynapseError(400, "Invalid user") + if b"\x00" in password: + raise SynapseError(400, "Invalid password") + + # str() because otherwise hmac complains that 'unicode' does not + # have the buffer interface + got_mac = str(register_json["mac"]) + + want_mac = hmac.new( + key=self.hs.config.registration_shared_secret.encode(), + digestmod=sha1, + ) + want_mac.update(user) + want_mac.update(b"\x00") + want_mac.update(password) + want_mac.update(b"\x00") + want_mac.update(b"admin" if admin else b"notadmin") + want_mac = want_mac.hexdigest() + + if compare_digest(want_mac, got_mac): + handler = self.handlers.registration_handler + user_id, token = yield handler.register( + localpart=user.lower(), + password=password, + admin=bool(admin), + ) + self._remove_session(session) + defer.returnValue({ + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + }) + else: + raise SynapseError( + 403, "HMAC incorrect", + ) + + +class CreateUserRestServlet(ClientV1RestServlet): + """Handles user creation via a server-to-server interface + """ + + PATTERNS = v1_only_client_path_patterns("/createUser$") + + def __init__(self, hs): + super(CreateUserRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_POST(self, request): + user_json = parse_json_object_from_request(request) + + access_token = self.auth.get_access_token_from_request(request) + app_service = self.store.get_app_service_by_token( + access_token + ) + if not app_service: + raise SynapseError(403, "Invalid application service token.") + + requester = create_requester(app_service.sender) + + logger.debug("creating user: %s", user_json) + response = yield self._do_create(requester, user_json) + + defer.returnValue((200, response)) + + def on_OPTIONS(self, request): + return 403, {} + + @defer.inlineCallbacks + def _do_create(self, requester, user_json): + assert_params_in_dict(user_json, ["localpart", "displayname"]) + + localpart = user_json["localpart"].encode("utf-8") + displayname = user_json["displayname"].encode("utf-8") + password_hash = user_json["password_hash"].encode("utf-8") \ + if user_json.get("password_hash") else None + + handler = self.handlers.registration_handler + user_id, token = yield handler.get_or_create_user( + requester=requester, + localpart=localpart, + displayname=displayname, + password_hash=password_hash + ) + + defer.returnValue({ + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + }) + + +def register_servlets(hs, http_server): + RegisterRestServlet(hs).register(http_server) + CreateUserRestServlet(hs).register(http_server) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index a5af36a99c..50418153fa 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -14,100 +14,30 @@ # limitations under the License. """ Tests REST events for /events paths.""" + from mock import Mock, NonCallableMock +from six import PY3 -# twisted imports from twisted.internet import defer -import synapse.rest.client.v1.events -import synapse.rest.client.v1.register -import synapse.rest.client.v1.room - -from tests import unittest - from ....utils import MockHttpResource, setup_test_homeserver from .utils import RestTestCase PATH_PREFIX = "/_matrix/client/api/v1" -class EventStreamPaginationApiTestCase(unittest.TestCase): - """ Tests event streaming query parameters and start/end keys used in the - Pagination stream API. """ - user_id = "sid1" - - def setUp(self): - # configure stream and inject items - pass - - def tearDown(self): - pass - - def TODO_test_long_poll(self): - # stream from 'end' key, send (self+other) message, expect message. - - # stream from 'END', send (self+other) message, expect message. - - # stream from 'end' key, send (self+other) topic, expect topic. - - # stream from 'END', send (self+other) topic, expect topic. - - # stream from 'end' key, send (self+other) invite, expect invite. - - # stream from 'END', send (self+other) invite, expect invite. - - pass - - def TODO_test_stream_forward(self): - # stream from START, expect injected items - - # stream from 'start' key, expect same content - - # stream from 'end' key, expect nothing - - # stream from 'END', expect nothing - - # The following is needed for cases where content is removed e.g. you - # left a room, so the token you're streaming from is > the one that - # would be returned naturally from START>END. - # stream from very new token (higher than end key), expect same token - # returned as end key - pass - - def TODO_test_limits(self): - # stream from a key, expect limit_num items - - # stream from START, expect limit_num items - - pass - - def TODO_test_range(self): - # stream from key to key, expect X items - - # stream from key to END, expect X items - - # stream from START to key, expect X items - - # stream from START to END, expect all items - pass - - def TODO_test_direction(self): - # stream from END to START and fwds, expect newest first - - # stream from END to START and bwds, expect oldest first - - # stream from START to END and fwds, expect oldest first - - # stream from START to END and bwds, expect newest first - - pass - - class EventStreamPermissionsTestCase(RestTestCase): """ Tests event streaming (GET /events). """ + if PY3: + skip = "Skip on Py3 until ported to use not V1 only register." + @defer.inlineCallbacks def setUp(self): + import synapse.rest.client.v1.events + import synapse.rest.client.v1_only.register + import synapse.rest.client.v1.room + self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) hs = yield setup_test_homeserver( @@ -125,7 +55,7 @@ class EventStreamPermissionsTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - synapse.rest.client.v1.register.register_servlets(hs, self.mock_resource) + synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index f15fb36213..83a23cd8fe 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -16,11 +16,12 @@ import json from mock import Mock +from six import PY3 from twisted.test.proto_helpers import MemoryReactorClock from synapse.http.server import JsonResource -from synapse.rest.client.v1.register import register_servlets +from synapse.rest.client.v1_only.register import register_servlets from synapse.util import Clock from tests import unittest @@ -31,6 +32,8 @@ class CreateUserServletTestCase(unittest.TestCase): """ Tests for CreateUserRestServlet. """ + if PY3: + skip = "Not ported to Python 3." def setUp(self): self.registration_handler = Mock() diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 6b5764095e..00fc796787 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -20,7 +20,6 @@ import json from mock import Mock, NonCallableMock from six.moves.urllib import parse as urlparse -# twisted imports from twisted.internet import defer import synapse.rest.client.v1.room @@ -86,6 +85,7 @@ class RoomBase(unittest.TestCase): self.resource = JsonResource(self.hs) synapse.rest.client.v1.room.register_servlets(self.hs, self.resource) + synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource) self.helper = RestHelper(self.hs, self.resource, self.user_id) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index 5ea9cc825f..e890f0feac 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -21,8 +21,12 @@ from synapse.types import UserID from synapse.util import Clock from tests import unittest -from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock -from tests.server import make_request, setup_test_homeserver, wait_until_result +from tests.server import ( + ThreadedMemoryReactorClock as MemoryReactorClock, + make_request, + setup_test_homeserver, + wait_until_result, +) PATH_PREFIX = "/_matrix/client/v2_alpha" diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 704cf97a40..03ec3993b2 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -20,8 +20,12 @@ from synapse.types import UserID from synapse.util import Clock from tests import unittest -from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock -from tests.server import make_request, setup_test_homeserver, wait_until_result +from tests.server import ( + ThreadedMemoryReactorClock as MemoryReactorClock, + make_request, + setup_test_homeserver, + wait_until_result, +) PATH_PREFIX = "/_matrix/client/v2_alpha" -- cgit 1.4.1 From e1a237eaabf0ba37f242897700f9bf00729976b8 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 20 Jul 2018 22:41:13 +1000 Subject: Admin API for creating new users (#3415) --- changelog.d/3415.misc | 0 docs/admin_api/register_api.rst | 63 ++++++++ scripts/register_new_matrix_user | 32 +++- synapse/rest/client/v1/admin.py | 122 +++++++++++++++ synapse/secrets.py | 42 +++++ synapse/server.py | 5 + tests/rest/client/v1/test_admin.py | 305 +++++++++++++++++++++++++++++++++++++ tests/utils.py | 3 + 8 files changed, 569 insertions(+), 3 deletions(-) create mode 100644 changelog.d/3415.misc create mode 100644 docs/admin_api/register_api.rst create mode 100644 synapse/secrets.py create mode 100644 tests/rest/client/v1/test_admin.py (limited to 'synapse/rest/client') diff --git a/changelog.d/3415.misc b/changelog.d/3415.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/admin_api/register_api.rst b/docs/admin_api/register_api.rst new file mode 100644 index 0000000000..209cd140fd --- /dev/null +++ b/docs/admin_api/register_api.rst @@ -0,0 +1,63 @@ +Shared-Secret Registration +========================== + +This API allows for the creation of users in an administrative and +non-interactive way. This is generally used for bootstrapping a Synapse +instance with administrator accounts. + +To authenticate yourself to the server, you will need both the shared secret +(``registration_shared_secret`` in the homeserver configuration), and a +one-time nonce. If the registration shared secret is not configured, this API +is not enabled. + +To fetch the nonce, you need to request one from the API:: + + > GET /_matrix/client/r0/admin/register + + < {"nonce": "thisisanonce"} + +Once you have the nonce, you can make a ``POST`` to the same URL with a JSON +body containing the nonce, username, password, whether they are an admin +(optional, False by default), and a HMAC digest of the content. + +As an example:: + + > POST /_matrix/client/r0/admin/register + > { + "nonce": "thisisanonce", + "username": "pepper_roni", + "password": "pizza", + "admin": true, + "mac": "mac_digest_here" + } + + < { + "access_token": "token_here", + "user_id": "@pepper_roni@test", + "home_server": "test", + "device_id": "device_id_here" + } + +The MAC is the hex digest output of the HMAC-SHA1 algorithm, with the key being +the shared secret and the content being the nonce, user, password, and either +the string "admin" or "notadmin", each separated by NULs. For an example of +generation in Python:: + + import hmac, hashlib + + def generate_mac(nonce, user, password, admin=False): + + mac = hmac.new( + key=shared_secret, + digestmod=hashlib.sha1, + ) + + mac.update(nonce.encode('utf8')) + mac.update(b"\x00") + mac.update(user.encode('utf8')) + mac.update(b"\x00") + mac.update(password.encode('utf8')) + mac.update(b"\x00") + mac.update(b"admin" if admin else b"notadmin") + + return mac.hexdigest() diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user index 12ed20d623..8c3d429351 100755 --- a/scripts/register_new_matrix_user +++ b/scripts/register_new_matrix_user @@ -26,11 +26,37 @@ import yaml def request_registration(user, password, server_location, shared_secret, admin=False): + req = urllib2.Request( + "%s/_matrix/client/r0/admin/register" % (server_location,), + headers={'Content-Type': 'application/json'} + ) + + try: + if sys.version_info[:3] >= (2, 7, 9): + # As of version 2.7.9, urllib2 now checks SSL certs + import ssl + f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) + else: + f = urllib2.urlopen(req) + body = f.read() + f.close() + nonce = json.loads(body)["nonce"] + except urllib2.HTTPError as e: + print "ERROR! Received %d %s" % (e.code, e.reason,) + if 400 <= e.code < 500: + if e.info().type == "application/json": + resp = json.load(e) + if "error" in resp: + print resp["error"] + sys.exit(1) + mac = hmac.new( key=shared_secret, digestmod=hashlib.sha1, ) + mac.update(nonce) + mac.update("\x00") mac.update(user) mac.update("\x00") mac.update(password) @@ -40,10 +66,10 @@ def request_registration(user, password, server_location, shared_secret, admin=F mac = mac.hexdigest() data = { - "user": user, + "nonce": nonce, + "username": user, "password": password, "mac": mac, - "type": "org.matrix.login.shared_secret", "admin": admin, } @@ -52,7 +78,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F print "Sending registration request..." req = urllib2.Request( - "%s/_matrix/client/api/v1/register" % (server_location,), + "%s/_matrix/client/r0/admin/register" % (server_location,), data=json.dumps(data), headers={'Content-Type': 'application/json'} ) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 2dc50e582b..9e9c175970 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import hmac import logging from six.moves import http_client @@ -63,6 +65,125 @@ class UsersRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) +class UserRegisterServlet(ClientV1RestServlet): + """ + Attributes: + NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted + nonces (dict[str, int]): The nonces that we will accept. A dict of + nonce to the time it was generated, in int seconds. + """ + PATTERNS = client_path_patterns("/admin/register") + NONCE_TIMEOUT = 60 + + def __init__(self, hs): + super(UserRegisterServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + self.reactor = hs.get_reactor() + self.nonces = {} + self.hs = hs + + def _clear_old_nonces(self): + """ + Clear out old nonces that are older than NONCE_TIMEOUT. + """ + now = int(self.reactor.seconds()) + + for k, v in list(self.nonces.items()): + if now - v > self.NONCE_TIMEOUT: + del self.nonces[k] + + def on_GET(self, request): + """ + Generate a new nonce. + """ + self._clear_old_nonces() + + nonce = self.hs.get_secrets().token_hex(64) + self.nonces[nonce] = int(self.reactor.seconds()) + return (200, {"nonce": nonce.encode('ascii')}) + + @defer.inlineCallbacks + def on_POST(self, request): + self._clear_old_nonces() + + if not self.hs.config.registration_shared_secret: + raise SynapseError(400, "Shared secret registration is not enabled") + + body = parse_json_object_from_request(request) + + if "nonce" not in body: + raise SynapseError( + 400, "nonce must be specified", errcode=Codes.BAD_JSON, + ) + + nonce = body["nonce"] + + if nonce not in self.nonces: + raise SynapseError( + 400, "unrecognised nonce", + ) + + # Delete the nonce, so it can't be reused, even if it's invalid + del self.nonces[nonce] + + if "username" not in body: + raise SynapseError( + 400, "username must be specified", errcode=Codes.BAD_JSON, + ) + else: + if (not isinstance(body['username'], str) or len(body['username']) > 512): + raise SynapseError(400, "Invalid username") + + username = body["username"].encode("utf-8") + if b"\x00" in username: + raise SynapseError(400, "Invalid username") + + if "password" not in body: + raise SynapseError( + 400, "password must be specified", errcode=Codes.BAD_JSON, + ) + else: + if (not isinstance(body['password'], str) or len(body['password']) > 512): + raise SynapseError(400, "Invalid password") + + password = body["password"].encode("utf-8") + if b"\x00" in password: + raise SynapseError(400, "Invalid password") + + admin = body.get("admin", None) + got_mac = body["mac"] + + want_mac = hmac.new( + key=self.hs.config.registration_shared_secret.encode(), + digestmod=hashlib.sha1, + ) + want_mac.update(nonce) + want_mac.update(b"\x00") + want_mac.update(username) + want_mac.update(b"\x00") + want_mac.update(password) + want_mac.update(b"\x00") + want_mac.update(b"admin" if admin else b"notadmin") + want_mac = want_mac.hexdigest() + + if not hmac.compare_digest(want_mac, got_mac): + raise SynapseError( + 403, "HMAC incorrect", + ) + + # Reuse the parts of RegisterRestServlet to reduce code duplication + from synapse.rest.client.v2_alpha.register import RegisterRestServlet + register = RegisterRestServlet(self.hs) + + (user_id, _) = yield register.registration_handler.register( + localpart=username.lower(), password=password, admin=bool(admin), + generate_token=False, + ) + + result = yield register._create_registration_details(user_id, body) + defer.returnValue((200, result)) + + class WhoisRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/whois/(?P[^/]*)") @@ -614,3 +735,4 @@ def register_servlets(hs, http_server): ShutdownRoomRestServlet(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server) ListMediaInRoom(hs).register(http_server) + UserRegisterServlet(hs).register(http_server) diff --git a/synapse/secrets.py b/synapse/secrets.py new file mode 100644 index 0000000000..f397daaa5e --- /dev/null +++ b/synapse/secrets.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Injectable secrets module for Synapse. + +See https://docs.python.org/3/library/secrets.html#module-secrets for the API +used in Python 3.6, and the API emulated in Python 2.7. +""" + +import six + +if six.PY3: + import secrets + + def Secrets(): + return secrets + + +else: + + import os + import binascii + + class Secrets(object): + def token_bytes(self, nbytes=32): + return os.urandom(nbytes) + + def token_hex(self, nbytes=32): + return binascii.hexlify(self.token_bytes(nbytes)) diff --git a/synapse/server.py b/synapse/server.py index 92bea96c5c..fd4f992258 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -74,6 +74,7 @@ from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, ) +from synapse.secrets import Secrets from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender @@ -158,6 +159,7 @@ class HomeServer(object): 'groups_server_handler', 'groups_attestation_signing', 'groups_attestation_renewer', + 'secrets', 'spam_checker', 'room_member_handler', 'federation_registry', @@ -405,6 +407,9 @@ class HomeServer(object): def build_groups_attestation_renewer(self): return GroupAttestionRenewer(self) + def build_secrets(self): + return Secrets() + def build_spam_checker(self): return SpamChecker(self) diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py new file mode 100644 index 0000000000..8c90145601 --- /dev/null +++ b/tests/rest/client/v1/test_admin.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import hmac +import json + +from mock import Mock + +from synapse.http.server import JsonResource +from synapse.rest.client.v1.admin import register_servlets +from synapse.util import Clock + +from tests import unittest +from tests.server import ( + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) + + +class UserRegisterTestCase(unittest.TestCase): + def setUp(self): + + self.clock = ThreadedMemoryReactorClock() + self.hs_clock = Clock(self.clock) + self.url = "/_matrix/client/r0/admin/register" + + self.registration_handler = Mock() + self.identity_handler = Mock() + self.login_handler = Mock() + self.device_handler = Mock() + self.device_handler.check_device_registered = Mock(return_value="FAKE") + + self.datastore = Mock(return_value=Mock()) + self.datastore.get_current_state_deltas = Mock(return_value=[]) + + self.secrets = Mock() + + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock + ) + + self.hs.config.registration_shared_secret = u"shared" + + self.hs.get_media_repository = Mock() + self.hs.get_deactivate_account_handler = Mock() + + self.resource = JsonResource(self.hs) + register_servlets(self.hs, self.resource) + + def test_disabled(self): + """ + If there is no shared secret, registration through this method will be + prevented. + """ + self.hs.config.registration_shared_secret = None + + request, channel = make_request("POST", self.url, b'{}') + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + 'Shared secret registration is not enabled', channel.json_body["error"] + ) + + def test_get_nonce(self): + """ + Calling GET on the endpoint will return a randomised nonce, using the + homeserver's secrets provider. + """ + secrets = Mock() + secrets.token_hex = Mock(return_value="abcd") + + self.hs.get_secrets = Mock(return_value=secrets) + + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + + self.assertEqual(channel.json_body, {"nonce": "abcd"}) + + def test_expired_nonce(self): + """ + Calling GET on the endpoint will return a randomised nonce, which will + only last for SALT_TIMEOUT (60s). + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + # 59 seconds + self.clock.advance(59) + + body = json.dumps({"nonce": nonce}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('username must be specified', channel.json_body["error"]) + + # 61 seconds + self.clock.advance(2) + + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('unrecognised nonce', channel.json_body["error"]) + + def test_register_incorrect_nonce(self): + """ + Only the provided nonce can be used, as it's checked in the MAC. + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ).encode('utf8') + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("HMAC incorrect", channel.json_body["error"]) + + def test_register_correct_nonce(self): + """ + When the correct nonce is provided, and the right key is provided, the + user is registered. + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ).encode('utf8') + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + + def test_nonce_reuse(self): + """ + A valid unrecognised nonce. + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ).encode('utf8') + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + + # Now, try and reuse it + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('unrecognised nonce', channel.json_body["error"]) + + def test_missing_parts(self): + """ + Synapse will complain if you don't give nonce, username, password, and + mac. Admin is optional. Additional checks are done for length and + type. + """ + def nonce(): + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + return channel.json_body["nonce"] + + # + # Nonce check + # + + # Must be present + body = json.dumps({}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('nonce must be specified', channel.json_body["error"]) + + # + # Username checks + # + + # Must be present + body = json.dumps({"nonce": nonce()}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('username must be specified', channel.json_body["error"]) + + # Must be a string + body = json.dumps({"nonce": nonce(), "username": 1234}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid username', channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid username', channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid username', channel.json_body["error"]) + + # + # Username checks + # + + # Must be present + body = json.dumps({"nonce": nonce(), "username": "a"}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('password must be specified', channel.json_body["error"]) + + # Must be a string + body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid password', channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid password', channel.json_body["error"]) + + # Super long + body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid password', channel.json_body["error"]) diff --git a/tests/utils.py b/tests/utils.py index e488238bb3..c3dbff8507 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -71,6 +71,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None config.user_directory_search_all_users = False config.user_consent_server_notice_content = None config.block_events_without_consent_error = None + config.media_storage_providers = [] + config.auto_join_rooms = [] # disable user directory updates, because they get done in the # background, which upsets the test runner. @@ -136,6 +138,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), + reactor=reactor, **kargs ) -- cgit 1.4.1