diff options
author | David Baker <dave@matrix.org> | 2016-08-11 14:09:13 +0100 |
---|---|---|
committer | David Baker <dave@matrix.org> | 2016-08-11 14:09:13 +0100 |
commit | b4ecf0b886c67437901e0af457c5f801ebde9a72 (patch) | |
tree | ef66b0684edcfeb4ad68d20375641f4654393f44 /synapse/rest/client | |
parent | Include the ts the notif was received at (diff) | |
parent | Merge pull request #1003 from matrix-org/erikj/redaction_prev_content (diff) | |
download | synapse-b4ecf0b886c67437901e0af457c5f801ebde9a72.tar.xz |
Merge remote-tracking branch 'origin/develop' into dbkr/notifications_api
Diffstat (limited to 'synapse/rest/client')
-rw-r--r-- | synapse/rest/client/v1/admin.py | 77 | ||||
-rw-r--r-- | synapse/rest/client/v1/base.py | 4 | ||||
-rw-r--r-- | synapse/rest/client/v1/events.py | 45 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 235 | ||||
-rw-r--r-- | synapse/rest/client/v1/push_rule.py | 6 | ||||
-rw-r--r-- | synapse/rest/client/v1/pusher.py | 57 | ||||
-rw-r--r-- | synapse/rest/client/v1/register.py | 33 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 25 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/_base.py | 13 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/account.py | 118 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/auth.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/devices.py | 100 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/keys.py | 102 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/register.py | 345 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/tokenrefresh.py | 12 | ||||
-rw-r--r-- | synapse/rest/client/versions.py | 6 |
16 files changed, 800 insertions, 380 deletions
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index aa05b3f023..b0cb31a448 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -46,5 +46,82 @@ class WhoisRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) +class PurgeMediaCacheRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/admin/purge_media_cache") + + def __init__(self, hs): + self.media_repository = hs.get_media_repository() + super(PurgeMediaCacheRestServlet, self).__init__(hs) + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + before_ts = request.args.get("before_ts", None) + if not before_ts: + raise SynapseError(400, "Missing 'before_ts' arg") + + logger.info("before_ts: %r", before_ts[0]) + + try: + before_ts = int(before_ts[0]) + except Exception: + raise SynapseError(400, "Invalid 'before_ts' arg") + + ret = yield self.media_repository.delete_old_remote_media(before_ts) + + defer.returnValue((200, ret)) + + +class PurgeHistoryRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns( + "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + ) + + @defer.inlineCallbacks + def on_POST(self, request, room_id, event_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + yield self.handlers.message_handler.purge_history(room_id, event_id) + + defer.returnValue((200, {})) + + +class DeactivateAccountRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") + + def __init__(self, hs): + self.store = hs.get_datastore() + super(DeactivateAccountRestServlet, self).__init__(hs) + + @defer.inlineCallbacks + def on_POST(self, request, target_user_id): + UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + # FIXME: Theoretically there is a race here wherein user resets password + # using threepid. + yield self.store.user_delete_access_tokens(target_user_id) + yield self.store.user_delete_threepids(target_user_id) + yield self.store.user_set_password_hash(target_user_id, None) + + defer.returnValue((200, {})) + + def register_servlets(hs, http_server): WhoisRestServlet(hs).register(http_server) + PurgeMediaCacheRestServlet(hs).register(http_server) + DeactivateAccountRestServlet(hs).register(http_server) + PurgeHistoryRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 1c020b7e2c..96b49b01f2 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.hs = hs self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index d1afa0f0d5..498bb9e18a 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -45,30 +45,27 @@ class EventStreamRestServlet(ClientV1RestServlet): raise SynapseError(400, "Guest users must specify room_id param") if "room_id" in request.args: room_id = request.args["room_id"][0] - try: - handler = self.handlers.event_stream_handler - pagin_config = PaginationConfig.from_request(request) - timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if "timeout" in request.args: - try: - timeout = int(request.args["timeout"][0]) - except ValueError: - raise SynapseError(400, "timeout must be in milliseconds.") - - as_client_event = "raw" not in request.args - - chunk = yield handler.get_stream( - requester.user.to_string(), - pagin_config, - timeout=timeout, - as_client_event=as_client_event, - affect_presence=(not is_guest), - room_id=room_id, - is_guest=is_guest, - ) - except: - logger.exception("Event stream failed") - raise + + handler = self.handlers.event_stream_handler + pagin_config = PaginationConfig.from_request(request) + timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS + if "timeout" in request.args: + try: + timeout = int(request.args["timeout"][0]) + except ValueError: + raise SynapseError(400, "timeout must be in milliseconds.") + + as_client_event = "raw" not in request.args + + chunk = yield handler.get_stream( + requester.user.to_string(), + pagin_config, + timeout=timeout, + as_client_event=as_client_event, + affect_presence=(not is_guest), + room_id=room_id, + is_guest=is_guest, + ) defer.returnValue((200, chunk)) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 3b5544851b..b31e27f7b3 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -54,10 +54,8 @@ class LoginRestServlet(ClientV1RestServlet): self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm self.cas_enabled = hs.config.cas_enabled - self.cas_server_url = hs.config.cas_server_url - self.cas_required_attributes = hs.config.cas_required_attributes - self.servername = hs.config.server_name - self.http_client = hs.get_simple_http_client() + self.auth_handler = self.hs.get_auth_handler() + self.device_handler = self.hs.get_device_handler() def on_GET(self, request): flows = [] @@ -108,17 +106,6 @@ class LoginRestServlet(ClientV1RestServlet): LoginRestServlet.JWT_TYPE): result = yield self.do_jwt_login(login_submission) defer.returnValue(result) - # TODO Delete this after all CAS clients switch to token login instead - elif self.cas_enabled and (login_submission["type"] == - LoginRestServlet.CAS_TYPE): - uri = "%s/proxyValidate" % (self.cas_server_url,) - args = { - "ticket": login_submission["ticket"], - "service": login_submission["service"] - } - body = yield self.http_client.get_raw(uri, args) - result = yield self.do_cas_login(body) - defer.returnValue(result) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) defer.returnValue(result) @@ -143,16 +130,24 @@ class LoginRestServlet(ClientV1RestServlet): user_id, self.hs.hostname ).to_string() - auth_handler = self.handlers.auth_handler - user_id, access_token, refresh_token = yield auth_handler.login_with_password( + auth_handler = self.auth_handler + user_id = yield auth_handler.validate_password_login( user_id=user_id, - password=login_submission["password"]) - + password=login_submission["password"], + ) + device_id = yield self._register_device(user_id, login_submission) + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) + ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -160,65 +155,27 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_token_login(self, login_submission): token = login_submission['token'] - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + device_id = yield self._register_device(user_id, login_submission) + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - @defer.inlineCallbacks - def do_cas_login(self, cas_response_body): - user, attributes = self.parse_cas_response(cas_response_body) - - for required_attribute, required_value in self.cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "refresh_token": refresh_token, - "home_server": self.hs.hostname, - } - - else: - user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } - - defer.returnValue((200, result)) - @defer.inlineCallbacks def do_jwt_login(self, login_submission): token = login_submission.get("token", None) @@ -243,19 +200,28 @@ class LoginRestServlet(ClientV1RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + auth_handler = self.auth_handler + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + device_id = yield self._register_device( + registered_user_id, login_submission + ) + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, } else: + # TODO: we should probably check that the register isn't going + # to fonx/change our user_id before registering the device + device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) @@ -267,32 +233,25 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) + def _register_device(self, user_id, login_submission): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) login_submission: dictionary supplied to /login call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get( + "initial_device_display_name") + return self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) class SAML2RestServlet(ClientV1RestServlet): @@ -338,18 +297,6 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue((200, {"status": "not_authenticated"})) -# TODO Delete this after all CAS clients switch to token login instead -class CasRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/login/cas", releases=()) - - def __init__(self, hs): - super(CasRestServlet, self).__init__(hs) - self.cas_server_url = hs.config.cas_server_url - - def on_GET(self, request): - return (200, {"serverUrl": self.cas_server_url}) - - class CasRedirectServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) @@ -381,6 +328,7 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_server_url = hs.config.cas_server_url self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_GET(self, request): @@ -412,14 +360,14 @@ class CasTicketServlet(ClientV1RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if not user_exists: - user_id, _ = ( + auth_handler = self.auth_handler + registered_user_id = yield auth_handler.check_user_exists(user_id) + if not registered_user_id: + registered_user_id, _ = ( yield self.handlers.registration_handler.register(localpart=user) ) - login_token = auth_handler.generate_short_term_login_token(user_id) + login_token = auth_handler.generate_short_term_login_token(registered_user_id) redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) @@ -433,30 +381,39 @@ class CasTicketServlet(ClientV1RestServlet): return urlparse.urlunparse(url_parts) def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) + user = None + attributes = None + try: + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise Exception("root of CAS response is not serviceResponse") + success = (root[0].tag.endswith("authenticationSuccess")) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. + tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + if user is None: + raise Exception("CAS response does not contain user") + if attributes is None: + raise Exception("CAS response does not contain attributes") + except Exception: + logger.error("Error parsing CAS response", exc_info=1) + raise LoginError(401, "Invalid CAS response", + errcode=Codes.UNAUTHORIZED) + if not success: + raise LoginError(401, "Unsuccessful CAS response", + errcode=Codes.UNAUTHORIZED) + return user, attributes def register_servlets(hs, http_server): @@ -466,5 +423,3 @@ def register_servlets(hs, http_server): if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) - CasRestServlet(hs).register(http_server) - # TODO PasswordResetRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 02d837ee6a..6bb4821ec6 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet): # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rawrules = yield self.store.get_push_rules_for_user(user_id) + rules = yield self.store.get_push_rules_for_user(user_id) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - - rules = format_push_rules_for_user(requester.user, rawrules, enabled_map) + rules = format_push_rules_for_user(requester.user, rules) path = request.postpath[1:] diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index ab928a16da..9a2ed6ed88 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -17,7 +17,11 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.push import PusherConfigException -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import ( + parse_json_object_from_request, parse_string, RestServlet +) +from synapse.http.server import finish_request +from synapse.api.errors import StoreError from .base import ClientV1RestServlet, client_path_patterns @@ -136,6 +140,57 @@ class PushersSetRestServlet(ClientV1RestServlet): return 200, {} +class PushersRemoveRestServlet(RestServlet): + """ + To allow pusher to be delete by clicking a link (ie. GET request) + """ + PATTERNS = client_path_patterns("/pushers/remove$") + SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>" + + def __init__(self, hs): + super(RestServlet, self).__init__() + self.hs = hs + self.notifier = hs.get_notifier() + self.auth = hs.get_v1auth() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, rights="delete_pusher") + user = requester.user + + app_id = parse_string(request, "app_id", required=True) + pushkey = parse_string(request, "pushkey", required=True) + + pusher_pool = self.hs.get_pusherpool() + + try: + yield pusher_pool.remove_pusher( + app_id=app_id, + pushkey=pushkey, + user_id=user.to_string(), + ) + except StoreError as se: + if se.code != 404: + # This is fine: they're already unsubscribed + raise + + self.notifier.on_new_replication_data() + + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Server", self.hs.version_string) + request.setHeader(b"Content-Length", b"%d" % ( + len(PushersRemoveRestServlet.SUCCESS_HTML), + )) + request.write(PushersRemoveRestServlet.SUCCESS_HTML) + finish_request(request) + defer.returnValue(None) + + def on_OPTIONS(self, _): + return 200, {} + + def register_servlets(hs, http_server): PushersRestServlet(hs).register(http_server) PushersSetRestServlet(hs).register(http_server) + PushersRemoveRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index e3f4fbb0bb..2383b9df86 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__(hs) # sessions are stored as: # self.sessions = { @@ -60,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet): # TODO: persistent storage self.sessions = {} self.enable_registration = hs.config.enable_registration + self.auth_handler = hs.get_auth_handler() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -299,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet): user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler - (user_id, token) = yield handler.appservice_register( + user_id = yield handler.appservice_register( user_localpart, as_token ) + token = yield self.auth_handler.issue_access_token(user_id) self._remove_session(session) defer.returnValue({ "user_id": user_id, @@ -324,6 +330,14 @@ class RegisterRestServlet(ClientV1RestServlet): raise SynapseError(400, "Shared secret registration is not enabled") user = register_json["user"].encode("utf-8") + password = register_json["password"].encode("utf-8") + admin = register_json.get("admin", None) + + # Its important to check as we use null bytes as HMAC field separators + if "\x00" in user: + raise SynapseError(400, "Invalid user") + if "\x00" in password: + raise SynapseError(400, "Invalid password") # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface @@ -331,17 +345,21 @@ class RegisterRestServlet(ClientV1RestServlet): want_mac = hmac.new( key=self.hs.config.registration_shared_secret, - msg=user, digestmod=sha1, - ).hexdigest() - - password = register_json["password"].encode("utf-8") + ) + want_mac.update(user) + want_mac.update("\x00") + want_mac.update(password) + want_mac.update("\x00") + want_mac.update("admin" if admin else "notadmin") + want_mac = want_mac.hexdigest() if compare_digest(want_mac, got_mac): handler = self.handlers.registration_handler user_id, token = yield handler.register( localpart=user, password=password, + admin=bool(admin), ) self._remove_session(session) defer.returnValue({ @@ -410,12 +428,15 @@ class CreateUserRestServlet(ClientV1RestServlet): raise SynapseError(400, "Failed to parse 'duration_seconds'") if duration_seconds > self.direct_user_creation_max_duration: duration_seconds = self.direct_user_creation_max_duration + password_hash = user_json["password_hash"].encode("utf-8") \ + if user_json.get("password_hash") else None handler = self.handlers.registration_handler user_id, token = yield handler.get_or_create_user( localpart=localpart, displayname=displayname, - duration_seconds=duration_seconds + duration_in_ms=(duration_seconds * 1000), + password_hash=password_hash ) defer.returnValue({ diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 644aa4e513..866a1e9120 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership +from synapse.api.filtering import Filter from synapse.types import UserID, RoomID, RoomAlias from synapse.events.utils import serialize_event from synapse.http.servlet import parse_json_object_from_request import logging import urllib +import ujson as json logger = logging.getLogger(__name__) @@ -72,8 +74,6 @@ class RoomCreateRestServlet(ClientV1RestServlet): def get_room_config(self, request): user_supplied_config = parse_json_object_from_request(request) - # default visibility - user_supplied_config.setdefault("visibility", "public") return user_supplied_config def on_OPTIONS(self, request): @@ -279,8 +279,16 @@ class PublicRoomListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - handler = self.handlers.room_list_handler - data = yield handler.get_public_room_list() + try: + yield self.auth.get_user_by_req(request) + except AuthError: + # This endpoint isn't authed, but its useful to know who's hitting + # it if they *do* supply an access token + pass + + handler = self.hs.get_room_list_handler() + data = yield handler.get_aggregated_public_room_list() + defer.returnValue((200, data)) @@ -321,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet): request, default_limit=10, ) as_client_event = "raw" not in request.args + filter_bytes = request.args.get("filter", None) + if filter_bytes: + filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") + event_filter = Filter(json.loads(filter_json)) + else: + event_filter = None handler = self.handlers.message_handler msgs = yield handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, - as_client_event=as_client_event + as_client_event=as_client_event, + event_filter=event_filter, ) defer.returnValue((200, msgs)) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index b6faa2b0e6..20e765f48f 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -25,7 +25,9 @@ import logging logger = logging.getLogger(__name__) -def client_v2_patterns(path_regex, releases=(0,)): +def client_v2_patterns(path_regex, releases=(0,), + v2_alpha=True, + unstable=True): """Creates a regex compiled client path with the correct client path prefix. @@ -35,9 +37,12 @@ def client_v2_patterns(path_regex, releases=(0,)): Returns: SRE_Pattern """ - patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)] - unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") - patterns.append(re.compile("^" + unstable_prefix + path_regex)) + patterns = [] + if v2_alpha: + patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)) + if unstable: + unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) for release in releases: new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) patterns.append(re.compile("^" + new_prefix + path_regex)) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c88c270537..eb49ad62e9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -28,14 +28,46 @@ import logging logger = logging.getLogger(__name__) +class PasswordRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/password/email/requestToken$") + + def __init__(self, hs): + super(PasswordRequestTokenRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if absent: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is None: + raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class PasswordRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/password") + PATTERNS = client_v2_patterns("/account/password$") def __init__(self, hs): super(PasswordRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_POST(self, request): @@ -89,15 +121,90 @@ class PasswordRestServlet(RestServlet): return 200, {} +class DeactivateAccountRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/deactivate$") + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + super(DeactivateAccountRestServlet, self).__init__() + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + authed, result, params, _ = yield self.auth_handler.check_auth([ + [LoginType.PASSWORD], + ], body, self.hs.get_ip_from_request(request)) + + if not authed: + defer.returnValue((401, result)) + + user_id = None + requester = None + + if LoginType.PASSWORD in result: + # if using password, they should also be logged in + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + if user_id != result[LoginType.PASSWORD]: + raise LoginError(400, "", Codes.UNKNOWN) + else: + logger.error("Auth succeeded but no known type!", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) + + # FIXME: Theoretically there is a race here wherein user resets password + # using threepid. + yield self.store.user_delete_access_tokens(user_id) + yield self.store.user_delete_threepids(user_id) + yield self.store.user_set_password_hash(user_id, None) + + defer.returnValue((200, {})) + + +class ThreepidRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") + + def __init__(self, hs): + self.hs = hs + super(ThreepidRequestTokenRestServlet, self).__init__() + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if absent: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is not None: + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class ThreepidRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid") + PATTERNS = client_v2_patterns("/account/3pid$") def __init__(self, hs): super(ThreepidRestServlet, self).__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_GET(self, request): @@ -157,5 +264,8 @@ class ThreepidRestServlet(RestServlet): def register_servlets(hs, http_server): + PasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) + DeactivateAccountRestServlet(hs).register(http_server) + ThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 78181b7b18..58d3cad6a1 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet): super(AuthRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler @defer.inlineCallbacks diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py new file mode 100644 index 0000000000..8fbd3d3dfc --- /dev/null +++ b/synapse/rest/client/v2_alpha/devices.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.http import servlet +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class DevicesRestServlet(servlet.RestServlet): + PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DevicesRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + devices = yield self.device_handler.get_devices_by_user( + requester.user.to_string() + ) + defer.returnValue((200, {"devices": devices})) + + +class DeviceRestServlet(servlet.RestServlet): + PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", + releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DeviceRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + device = yield self.device_handler.get_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, device)) + + @defer.inlineCallbacks + def on_DELETE(self, request, device_id): + # XXX: it's not completely obvious we want to expose this endpoint. + # It allows the client to delete access tokens, which feels like a + # thing which merits extra auth. But if we want to do the interactive- + # auth dance, we should really make it possible to delete more than one + # device at a time. + requester = yield self.auth.get_user_by_req(request) + yield self.device_handler.delete_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_PUT(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + + body = servlet.parse_json_object_from_request(request) + yield self.device_handler.update_device( + requester.user.to_string(), + device_id, + body + ) + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + DevicesRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 89ab39491c..c5ff16adf3 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -13,24 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import simplejson as json +from canonicaljson import encode_canonical_json from twisted.internet import defer +import synapse.api.errors +import synapse.server +import synapse.types from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID - -from canonicaljson import encode_canonical_json - from ._base import client_v2_patterns -import logging -import simplejson as json - logger = logging.getLogger(__name__) class KeyUploadServlet(RestServlet): """ - POST /keys/upload/<device_id> HTTP/1.1 + POST /keys/upload HTTP/1.1 Content-Type: application/json { @@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=()) + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", + releases=()) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(KeyUploadServlet, self).__init__() self.store = hs.get_datastore() self.clock = hs.get_clock() self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def on_POST(self, request, device_id): requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() - # TODO: Check that the device_id matches that in the authentication - # or derive the device_id from the authentication instead. body = parse_json_object_from_request(request) + if device_id is not None: + # passing the device_id here is deprecated; however, we allow it + # for now for compatibility with older clients. + if (requester.device_id is not None and + device_id != requester.device_id): + logger.warning("Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, device_id) + else: + device_id = requester.device_id + + if device_id is None: + raise synapse.api.errors.SynapseError( + 400, + "To upload keys, you must pass device_id when authenticating" + ) + time_now = self.clock.time_msec() # TODO: Validate the JSON to make sure it has the right keys. @@ -102,13 +125,12 @@ class KeyUploadServlet(RestServlet): user_id, device_id, time_now, key_list ) - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) - defer.returnValue((200, {"one_time_key_counts": result})) - - @defer.inlineCallbacks - def on_GET(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + # the device should have been registered already, but it may have been + # deleted due to a race with a DELETE request. Or we may be using an + # old access_token without an associated device_id. Either way, we + # need to double-check the device is registered to avoid ending up with + # keys without a corresponding device. + self.device_handler.check_device_registered(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id) defer.returnValue((200, {"one_time_key_counts": result})) @@ -162,17 +184,19 @@ class KeyQueryServlet(RestServlet): ) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ super(KeyQueryServlet, self).__init__() - self.store = hs.get_datastore() self.auth = hs.get_auth() - self.federation = hs.get_replication_layer() - self.is_mine = hs.is_mine + self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): yield self.auth.get_user_by_req(request) body = parse_json_object_from_request(request) - result = yield self.handle_request(body) + result = yield self.e2e_keys_handler.query_devices(body) defer.returnValue(result) @defer.inlineCallbacks @@ -181,45 +205,11 @@ class KeyQueryServlet(RestServlet): auth_user_id = requester.user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] - result = yield self.handle_request( + result = yield self.e2e_keys_handler.query_devices( {"device_keys": {user_id: device_ids}} ) defer.returnValue(result) - @defer.inlineCallbacks - def handle_request(self, body): - local_query = [] - remote_queries = {} - for user_id, device_ids in body.get("device_keys", {}).items(): - user = UserID.from_string(user_id) - if self.is_mine(user): - if not device_ids: - local_query.append((user_id, None)) - else: - for device_id in device_ids: - local_query.append((user_id, device_id)) - else: - remote_queries.setdefault(user.domain, {})[user_id] = list( - device_ids - ) - results = yield self.store.get_e2e_device_keys(local_query) - - json_result = {} - for user_id, device_keys in results.items(): - for device_id, json_bytes in device_keys.items(): - json_result.setdefault(user_id, {})[device_id] = json.loads( - json_bytes - ) - - for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.query_client_keys( - destination, {"device_keys": device_keys} - ) - for user_id, keys in remote_result["device_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys - defer.returnValue((200, {"device_keys": json_result})) - class OneTimeKeyServlet(RestServlet): """ diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1ecc02d94d..943f5676a3 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -41,17 +41,59 @@ else: logger = logging.getLogger(__name__) +class RegisterRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/register/email/requestToken$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RegisterRequestTokenRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if len(absent) > 0: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is not None: + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class RegisterRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register") + PATTERNS = client_v2_patterns("/register$") def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def on_POST(self, request): @@ -70,10 +112,6 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind,) ) - if '/register/email/requestToken' in request.path: - ret = yield self.onEmailTokenRequest(request) - defer.returnValue(ret) - body = parse_json_object_from_request(request) # we do basic sanity checks here because the auth layer will store these @@ -104,11 +142,12 @@ class RegisterRestServlet(RestServlet): # Set the desired user according to the AS API (which uses the # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. - if isinstance(body.get("user"), basestring): - desired_username = body["user"] - result = yield self._do_appservice_registration( - desired_username, request.args["access_token"][0] - ) + desired_username = body.get("user", desired_username) + + if isinstance(desired_username, basestring): + result = yield self._do_appservice_registration( + desired_username, request.args["access_token"][0], body + ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -117,7 +156,7 @@ class RegisterRestServlet(RestServlet): # FIXME: Should we really be determining if this is shared secret # auth based purely on the 'mac' key? result = yield self._do_shared_secret_registration( - desired_username, desired_password, body["mac"] + desired_username, desired_password, body ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -157,12 +196,12 @@ class RegisterRestServlet(RestServlet): [LoginType.EMAIL_IDENTITY] ] - authed, result, params, session_id = yield self.auth_handler.check_auth( + authed, auth_result, params, session_id = yield self.auth_handler.check_auth( flows, body, self.hs.get_ip_from_request(request) ) if not authed: - defer.returnValue((401, result)) + defer.returnValue((401, auth_result)) return if registered_user_id is not None: @@ -170,106 +209,58 @@ class RegisterRestServlet(RestServlet): "Already registered user ID %r for this session", registered_user_id ) - access_token = yield self.auth_handler.issue_access_token(registered_user_id) - refresh_token = yield self.auth_handler.issue_refresh_token( - registered_user_id + # don't re-register the email address + add_email = False + else: + # NB: This may be from the auth handler and NOT from the POST + if 'password' not in params: + raise SynapseError(400, "Missing password.", + Codes.MISSING_PARAM) + + desired_username = params.get("username", None) + new_password = params.get("password", None) + guest_access_token = params.get("guest_access_token", None) + + (registered_user_id, _) = yield self.registration_handler.register( + localpart=desired_username, + password=new_password, + guest_access_token=guest_access_token, + generate_token=False, ) - defer.returnValue((200, { - "user_id": registered_user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "refresh_token": refresh_token, - })) - - # NB: This may be from the auth handler and NOT from the POST - if 'password' not in params: - raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM) - - desired_username = params.get("username", None) - new_password = params.get("password", None) - guest_access_token = params.get("guest_access_token", None) - - (user_id, token) = yield self.registration_handler.register( - localpart=desired_username, - password=new_password, - guest_access_token=guest_access_token, - ) - # remember that we've now registered that user account, and with what - # user ID (since the user may not have specified) - self.auth_handler.set_session_data( - session_id, "registered_user_id", user_id + # remember that we've now registered that user account, and with + # what user ID (since the user may not have specified) + self.auth_handler.set_session_data( + session_id, "registered_user_id", registered_user_id + ) + + add_email = True + + return_dict = yield self._create_registration_details( + registered_user_id, params ) - if result and LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - - for reqd in ['medium', 'address', 'validated_at']: - if reqd not in threepid: - logger.info("Can't add incomplete 3pid") - else: - yield self.auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], - ) - - # And we add an email pusher for them by default, but only - # if email notifications are enabled (so people don't start - # getting mail spam where they weren't before if email - # notifs are set up on a home server) - if ( - self.hs.config.email_enable_notifs and - self.hs.config.email_notif_for_new_users - ): - # Pull the ID of the access token back out of the db - # It would really make more sense for this to be passed - # up when the access token is saved, but that's quite an - # invasive change I'd rather do separately. - user_tuple = yield self.store.get_user_by_access_token( - token - ) - - yield self.hs.get_pusherpool().add_pusher( - user_id=user_id, - access_token=user_tuple["token_id"], - kind="email", - app_id="m.email", - app_display_name="Email Notifications", - device_display_name=threepid["address"], - pushkey=threepid["address"], - lang=None, # We don't know a user's language here - data={}, - ) - - if 'bind_email' in params and params['bind_email']: - logger.info("bind_email specified: binding") - - emailThreepid = result[LoginType.EMAIL_IDENTITY] - threepid_creds = emailThreepid['threepid_creds'] - logger.debug("Binding emails %s to %s" % ( - emailThreepid, user_id - )) - yield self.identity_handler.bind_threepid(threepid_creds, user_id) - else: - logger.info("bind_email not specified: not binding email") - - result = yield self._create_registration_details(user_id, token) - defer.returnValue((200, result)) + if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result: + threepid = auth_result[LoginType.EMAIL_IDENTITY] + yield self._register_email_threepid( + registered_user_id, threepid, return_dict["access_token"], + params.get("bind_email") + ) + + defer.returnValue((200, return_dict)) def on_OPTIONS(self, _): return 200, {} @defer.inlineCallbacks - def _do_appservice_registration(self, username, as_token): - (user_id, token) = yield self.registration_handler.appservice_register( + def _do_appservice_registration(self, username, as_token, body): + user_id = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue((yield self._create_registration_details(user_id, token))) + defer.returnValue((yield self._create_registration_details(user_id, body))) @defer.inlineCallbacks - def _do_shared_secret_registration(self, username, password, mac): + def _do_shared_secret_registration(self, username, password, body): if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") @@ -277,7 +268,7 @@ class RegisterRestServlet(RestServlet): # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface - got_mac = str(mac) + got_mac = str(body["mac"]) want_mac = hmac.new( key=self.hs.config.registration_shared_secret, @@ -290,43 +281,132 @@ class RegisterRestServlet(RestServlet): 403, "HMAC incorrect", ) - (user_id, token) = yield self.registration_handler.register( - localpart=username, password=password + (user_id, _) = yield self.registration_handler.register( + localpart=username, password=password, generate_token=False, ) - defer.returnValue((yield self._create_registration_details(user_id, token))) - @defer.inlineCallbacks - def _create_registration_details(self, user_id, token): - refresh_token = yield self.auth_handler.issue_refresh_token(user_id) - defer.returnValue({ - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - "refresh_token": refresh_token, - }) + result = yield self._create_registration_details(user_id, body) + defer.returnValue(result) @defer.inlineCallbacks - def onEmailTokenRequest(self, request): - body = parse_json_object_from_request(request) + def _register_email_threepid(self, user_id, threepid, token, bind_email): + """Add an email address as a 3pid identifier + + Also adds an email pusher for the email address, if configured in the + HS config + + Also optionally binds emails to the given user_id on the identity server + + Args: + user_id (str): id of user + threepid (object): m.login.email.identity auth response + token (str): access_token for the user + bind_email (bool): true if the client requested the email to be + bound at the identity server + Returns: + defer.Deferred: + """ + reqd = ('medium', 'address', 'validated_at') + if any(x not in threepid for x in reqd): + logger.info("Can't add incomplete 3pid") + defer.returnValue() + + yield self.auth_handler.add_threepid( + user_id, + threepid['medium'], + threepid['address'], + threepid['validated_at'], + ) - required = ['id_server', 'client_secret', 'email', 'send_attempt'] - absent = [] - for k in required: - if k not in body: - absent.append(k) + # And we add an email pusher for them by default, but only + # if email notifications are enabled (so people don't start + # getting mail spam where they weren't before if email + # notifs are set up on a home server) + if (self.hs.config.email_enable_notifs and + self.hs.config.email_notif_for_new_users): + # Pull the ID of the access token back out of the db + # It would really make more sense for this to be passed + # up when the access token is saved, but that's quite an + # invasive change I'd rather do separately. + user_tuple = yield self.store.get_user_by_access_token( + token + ) + token_id = user_tuple["token_id"] + + yield self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name=threepid["address"], + pushkey=threepid["address"], + lang=None, # We don't know a user's language here + data={}, + ) - if len(absent) > 0: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + if bind_email: + logger.info("bind_email specified: binding") + logger.debug("Binding emails %s to %s" % ( + threepid, user_id + )) + yield self.identity_handler.bind_threepid( + threepid['threepid_creds'], user_id + ) + else: + logger.info("bind_email not specified: not binding email") - existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', body['email'] + @defer.inlineCallbacks + def _create_registration_details(self, user_id, params): + """Complete registration of newly-registered user + + Allocates device_id if one was not given; also creates access_token + and refresh_token. + + Args: + (str) user_id: full canonical @user:id + (object) params: registration parameters, from which we pull + device_id and initial_device_name + Returns: + defer.Deferred: (object) dictionary for response from /register + """ + device_id = yield self._register_device(user_id, params) + + access_token, refresh_token = ( + yield self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, + initial_display_name=params.get("initial_device_display_name") + ) ) - if existingUid is not None: - raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + defer.returnValue({ + "user_id": user_id, + "access_token": access_token, + "home_server": self.hs.hostname, + "refresh_token": refresh_token, + "device_id": device_id, + }) - ret = yield self.identity_handler.requestEmailToken(**body) - defer.returnValue((200, ret)) + def _register_device(self, user_id, params): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) params: registration parameters, from which we pull + device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + # register the user's device + device_id = params.get("device_id") + initial_display_name = params.get("initial_device_display_name") + device_id = self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + return device_id @defer.inlineCallbacks def _do_guest_registration(self): @@ -336,7 +416,11 @@ class RegisterRestServlet(RestServlet): generate_token=False, make_guest=True ) - access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) + access_token = self.auth_handler.generate_access_token( + user_id, ["guest = true"] + ) + # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok + # so long as we don't return a refresh_token here. defer.returnValue((200, { "user_id": user_id, "access_token": access_token, @@ -345,4 +429,5 @@ class RegisterRestServlet(RestServlet): def register_servlets(hs, http_server): + RegisterRequestTokenRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index a158c2209a..0d312c91d4 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -38,10 +38,14 @@ class TokenRefreshRestServlet(RestServlet): body = parse_json_object_from_request(request) try: old_refresh_token = body["refresh_token"] - auth_handler = self.hs.get_handlers().auth_handler - (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token) - new_access_token = yield auth_handler.issue_access_token(user_id) + auth_handler = self.hs.get_auth_handler() + refresh_result = yield self.store.exchange_refresh_token( + old_refresh_token, auth_handler.generate_refresh_token + ) + (user_id, new_refresh_token, device_id) = refresh_result + new_access_token = yield auth_handler.issue_access_token( + user_id, device_id + ) defer.returnValue((200, { "access_token": new_access_token, "refresh_token": new_refresh_token, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index ca5468c402..e984ea47db 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet): def on_GET(self, request): return (200, { - "versions": ["r0.0.1"] + "versions": [ + "r0.0.1", + "r0.1.0", + "r0.2.0", + ] }) |