diff options
author | Mark Haines <mjark@negativecurvature.net> | 2015-04-28 11:00:27 +0100 |
---|---|---|
committer | Mark Haines <mjark@negativecurvature.net> | 2015-04-28 11:00:27 +0100 |
commit | 9182f876645a27eb9599c99963876b12067fe93a (patch) | |
tree | 158a1e1213f73f2395389d5a861b5eb07f8eb36f | |
parent | Merge pull request #133 from matrix-org/invite_power_level (diff) | |
parent | Add commentage. (diff) | |
download | synapse-9182f876645a27eb9599c99963876b12067fe93a.tar.xz |
Merge pull request #126 from matrix-org/csauth
Client / Server Auth Refactor
34 files changed, 1323 insertions, 287 deletions
diff --git a/CAPTCHA_SETUP b/CAPTCHA_SETUP new file mode 100644 index 0000000000..75ff80981b --- /dev/null +++ b/CAPTCHA_SETUP @@ -0,0 +1,31 @@ +Captcha can be enabled for this home server. This file explains how to do that. +The captcha mechanism used is Google's ReCaptcha. This requires API keys from Google. + +Getting keys +------------ +Requires a public/private key pair from: + +https://developers.google.com/recaptcha/ + + +Setting ReCaptcha Keys +---------------------- +The keys are a config option on the home server config. If they are not +visible, you can generate them via --generate-config. Set the following value: + + recaptcha_public_key: YOUR_PUBLIC_KEY + recaptcha_private_key: YOUR_PRIVATE_KEY + +In addition, you MUST enable captchas via: + + enable_registration_captcha: true + +Configuring IP used for auth +---------------------------- +The ReCaptcha API requires that the IP address of the user who solved the +captcha is sent. If the client is connecting through a proxy or load balancer, +it may be required to use the X-Forwarded-For (XFF) header instead of the origin +IP address. This can be configured as an option on the home server like so: + + captcha_ip_origin_is_x_forwarded: true + diff --git a/register_new_matrix_user b/register_new_matrix_user index 4a520bdb5d..0ca83795a3 100755 --- a/register_new_matrix_user +++ b/register_new_matrix_user @@ -33,10 +33,9 @@ def request_registration(user, password, server_location, shared_secret): ).hexdigest() data = { - "user": user, + "username": user, "password": password, "mac": mac, - "type": "org.matrix.login.shared_secret", } server_location = server_location.rstrip("/") @@ -44,7 +43,7 @@ def request_registration(user, password, server_location, shared_secret): print "Sending registration request..." req = urllib2.Request( - "%s/_matrix/client/api/v1/register" % (server_location,), + "%s/_matrix/client/v2_alpha/register" % (server_location,), data=json.dumps(data), headers={'Content-Type': 'application/json'} ) diff --git a/static/client/register/style.css b/static/client/register/style.css index a3398852b9..5a7b6eebf2 100644 --- a/static/client/register/style.css +++ b/static/client/register/style.css @@ -37,9 +37,13 @@ textarea, input { margin: auto } +.g-recaptcha div { + margin: auto; +} + #registrationForm { text-align: left; - padding: 1em; + padding: 5px; margin-bottom: 40px; display: inline-block; diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 935dffbabe..77322a5c10 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules -from synapse.api.errors import AuthError, StoreError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.types import UserID, ClientInfo @@ -40,6 +40,7 @@ class Auth(object): self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() + self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 def check(self, event, auth_events): """ Checks if this event is correctly authed. @@ -369,7 +370,10 @@ class Auth(object): defer.returnValue((user, ClientInfo(device_id, token_id))) except KeyError: - raise AuthError(403, "Missing access token.") + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) @defer.inlineCallbacks def get_user_by_token(self, token): @@ -383,21 +387,20 @@ class Auth(object): Raises: AuthError if no user by that token exists or the token is invalid. """ - try: - ret = yield self.store.get_user_by_token(token) - if not ret: - raise StoreError(400, "Unknown token") - user_info = { - "admin": bool(ret.get("admin", False)), - "device_id": ret.get("device_id"), - "user": UserID.from_string(ret.get("name")), - "token_id": ret.get("token_id", None), - } + ret = yield self.store.get_user_by_token(token) + if not ret: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", + errcode=Codes.UNKNOWN_TOKEN + ) + user_info = { + "admin": bool(ret.get("admin", False)), + "device_id": ret.get("device_id"), + "user": UserID.from_string(ret.get("name")), + "token_id": ret.get("token_id", None), + } - defer.returnValue(user_info) - except StoreError: - raise AuthError(403, "Unrecognised access token.", - errcode=Codes.UNKNOWN_TOKEN) + defer.returnValue(user_info) @defer.inlineCallbacks def get_appservice_by_req(self, request): @@ -405,11 +408,16 @@ class Auth(object): token = request.args["access_token"][0] service = yield self.store.get_app_service_by_token(token) if not service: - raise AuthError(403, "Unrecognised access token.", - errcode=Codes.UNKNOWN_TOKEN) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Unrecognised access token.", + errcode=Codes.UNKNOWN_TOKEN + ) defer.returnValue(service) except KeyError: - raise AuthError(403, "Missing access token.") + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token." + ) def is_server_admin(self, user): return self.store.is_server_admin(user) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index b16bf4247d..d8a18ee87b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -59,6 +59,9 @@ class LoginType(object): EMAIL_URL = u"m.login.email.url" EMAIL_IDENTITY = u"m.login.email.identity" RECAPTCHA = u"m.login.recaptcha" + DUMMY = u"m.login.dummy" + + # Only for C/S API v1 APPLICATION_SERVICE = u"m.login.application_service" SHARED_SECRET = u"org.matrix.login.shared_secret" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 72d2bd5b4c..0b3320e62c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -31,6 +31,7 @@ class Codes(object): BAD_PAGINATION = "M_BAD_PAGINATION" UNKNOWN = "M_UNKNOWN" NOT_FOUND = "M_NOT_FOUND" + MISSING_TOKEN = "M_MISSING_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" @@ -38,6 +39,7 @@ class Codes(object): MISSING_PARAM = "M_MISSING_PARAM" TOO_LARGE = "M_TOO_LARGE" EXCLUSIVE = "M_EXCLUSIVE" + THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" class CodeMessageException(RuntimeError): diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index 7e21c7414d..07fbfadc0f 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -20,6 +20,7 @@ class CaptchaConfig(Config): def __init__(self, args): super(CaptchaConfig, self).__init__(args) self.recaptcha_private_key = args.recaptcha_private_key + self.recaptcha_public_key = args.recaptcha_public_key self.enable_registration_captcha = args.enable_registration_captcha self.captcha_ip_origin_is_x_forwarded = ( args.captcha_ip_origin_is_x_forwarded @@ -31,8 +32,12 @@ class CaptchaConfig(Config): super(CaptchaConfig, cls).add_arguments(parser) group = parser.add_argument_group("recaptcha") group.add_argument( + "--recaptcha-public-key", type=str, default="YOUR_PUBLIC_KEY", + help="This Home Server's ReCAPTCHA public key." + ) + group.add_argument( "--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY", - help="The matching private key for the web client's public key." + help="This Home Server's ReCAPTCHA private key." ) group.add_argument( "--enable-registration-captcha", type=bool, default=False, diff --git a/synapse/config/email.py b/synapse/config/email.py deleted file mode 100644 index f0854f8c37..0000000000 --- a/synapse/config/email.py +++ /dev/null @@ -1,42 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ._base import Config - - -class EmailConfig(Config): - - def __init__(self, args): - super(EmailConfig, self).__init__(args) - self.email_from_address = args.email_from_address - self.email_smtp_server = args.email_smtp_server - - @classmethod - def add_arguments(cls, parser): - super(EmailConfig, cls).add_arguments(parser) - email_group = parser.add_argument_group("email") - email_group.add_argument( - "--email-from-address", - default="FROM@EXAMPLE.COM", - help="The address to send emails from (e.g. for password resets)." - ) - email_group.add_argument( - "--email-smtp-server", - default="", - help=( - "The SMTP server to send emails from (e.g. for password" - " resets)." - ) - ) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 3edfadb98b..efbdd93c25 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -20,7 +20,6 @@ from .database import DatabaseConfig from .ratelimiting import RatelimitConfig from .repository import ContentRepositoryConfig from .captcha import CaptchaConfig -from .email import EmailConfig from .voip import VoipConfig from .registration import RegistrationConfig from .metrics import MetricsConfig @@ -29,7 +28,7 @@ from .appservice import AppServiceConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, - EmailConfig, VoipConfig, RegistrationConfig, + VoipConfig, RegistrationConfig, MetricsConfig, AppServiceConfig,): pass diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 0c51d615ec..685792dbdc 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -30,6 +30,8 @@ from .typing import TypingNotificationHandler from .admin import AdminHandler from .appservice import ApplicationServicesHandler from .sync import SyncHandler +from .auth import AuthHandler +from .identity import IdentityHandler class Handlers(object): @@ -64,3 +66,5 @@ class Handlers(object): ) ) self.sync_handler = SyncHandler(hs) + self.auth_handler = AuthHandler(hs) + self.identity_handler = IdentityHandler(hs) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py new file mode 100644 index 0000000000..2e8009d3c3 --- /dev/null +++ b/synapse/handlers/auth.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import BaseHandler +from synapse.api.constants import LoginType +from synapse.types import UserID +from synapse.api.errors import LoginError, Codes +from synapse.http.client import SimpleHttpClient +from synapse.util.async import run_on_reactor + +from twisted.web.client import PartialDownloadError + +import logging +import bcrypt +import simplejson + +import synapse.util.stringutils as stringutils + + +logger = logging.getLogger(__name__) + + +class AuthHandler(BaseHandler): + + def __init__(self, hs): + super(AuthHandler, self).__init__(hs) + self.checkers = { + LoginType.PASSWORD: self._check_password_auth, + LoginType.RECAPTCHA: self._check_recaptcha, + LoginType.EMAIL_IDENTITY: self._check_email_identity, + LoginType.DUMMY: self._check_dummy_auth, + } + self.sessions = {} + + @defer.inlineCallbacks + def check_auth(self, flows, clientdict, clientip=None): + """ + Takes a dictionary sent by the client in the login / registration + protocol and handles the login flow. + + Args: + flows: list of list of stages + authdict: The dictionary from the client root level, not the + 'auth' key: this method prompts for auth if none is sent. + Returns: + A tuple of authed, dict, dict where authed is true if the client + has successfully completed an auth flow. If it is true, the first + dict contains the authenticated credentials of each stage. + + If authed is false, the first dictionary is the server response to + the login request and should be passed back to the client. + + In either case, the second dict contains the parameters for this + request (which may have been given only in a previous call). + """ + + authdict = None + sid = None + if clientdict and 'auth' in clientdict: + authdict = clientdict['auth'] + del clientdict['auth'] + if 'session' in authdict: + sid = authdict['session'] + sess = self._get_session_info(sid) + + if len(clientdict) > 0: + # This was designed to allow the client to omit the parameters + # and just supply the session in subsequent calls so it split + # auth between devices by just sharing the session, (eg. so you + # could continue registration from your phone having clicked the + # email auth link on there). It's probably too open to abuse + # because it lets unauthenticated clients store arbitrary objects + # on a home server. + # sess['clientdict'] = clientdict + # self._save_session(sess) + pass + elif 'clientdict' in sess: + clientdict = sess['clientdict'] + + if not authdict: + defer.returnValue( + (False, self._auth_dict_for_flows(flows, sess), clientdict) + ) + + if 'creds' not in sess: + sess['creds'] = {} + creds = sess['creds'] + + # check auth type currently being presented + if 'type' in authdict: + if authdict['type'] not in self.checkers: + raise LoginError(400, "", Codes.UNRECOGNIZED) + result = yield self.checkers[authdict['type']](authdict, clientip) + if result: + creds[authdict['type']] = result + self._save_session(sess) + + for f in flows: + if len(set(f) - set(creds.keys())) == 0: + logger.info("Auth completed with creds: %r", creds) + self._remove_session(sess) + defer.returnValue((True, creds, clientdict)) + + ret = self._auth_dict_for_flows(flows, sess) + ret['completed'] = creds.keys() + defer.returnValue((False, ret, clientdict)) + + @defer.inlineCallbacks + def add_oob_auth(self, stagetype, authdict, clientip): + """ + Adds the result of out-of-band authentication into an existing auth + session. Currently used for adding the result of fallback auth. + """ + if stagetype not in self.checkers: + raise LoginError(400, "", Codes.MISSING_PARAM) + if 'session' not in authdict: + raise LoginError(400, "", Codes.MISSING_PARAM) + + sess = self._get_session_info( + authdict['session'] + ) + if 'creds' not in sess: + sess['creds'] = {} + creds = sess['creds'] + + result = yield self.checkers[stagetype](authdict, clientip) + if result: + creds[stagetype] = result + self._save_session(sess) + defer.returnValue(True) + defer.returnValue(False) + + @defer.inlineCallbacks + def _check_password_auth(self, authdict, _): + if "user" not in authdict or "password" not in authdict: + raise LoginError(400, "", Codes.MISSING_PARAM) + + user = authdict["user"] + password = authdict["password"] + if not user.startswith('@'): + user = UserID.create(user, self.hs.hostname).to_string() + + user_info = yield self.store.get_user_by_id(user_id=user) + if not user_info: + logger.warn("Attempted to login as %s but they do not exist", user) + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + stored_hash = user_info[0]["password_hash"] + if bcrypt.checkpw(password, stored_hash): + defer.returnValue(user) + else: + logger.warn("Failed password login for user %s", user) + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + @defer.inlineCallbacks + def _check_recaptcha(self, authdict, clientip): + try: + user_response = authdict["response"] + except KeyError: + # Client tried to provide captcha but didn't give the parameter: + # bad request. + raise LoginError( + 400, "Captcha response is required", + errcode=Codes.CAPTCHA_NEEDED + ) + + logger.info( + "Submitting recaptcha response %s with remoteip %s", + user_response, clientip + ) + + # TODO: get this from the homeserver rather than creating a new one for + # each request + try: + client = SimpleHttpClient(self.hs) + data = yield client.post_urlencoded_get_json( + "https://www.google.com/recaptcha/api/siteverify", + args={ + 'secret': self.hs.config.recaptcha_private_key, + 'response': user_response, + 'remoteip': clientip, + } + ) + except PartialDownloadError as pde: + # Twisted is silly + data = pde.response + resp_body = simplejson.loads(data) + if 'success' in resp_body and resp_body['success']: + defer.returnValue(True) + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + @defer.inlineCallbacks + def _check_email_identity(self, authdict, _): + yield run_on_reactor() + + if 'threepid_creds' not in authdict: + raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) + + threepid_creds = authdict['threepid_creds'] + identity_handler = self.hs.get_handlers().identity_handler + + logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,)) + threepid = yield identity_handler.threepid_from_creds(threepid_creds) + + if not threepid: + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + threepid['threepid_creds'] = authdict['threepid_creds'] + + defer.returnValue(threepid) + + @defer.inlineCallbacks + def _check_dummy_auth(self, authdict, _): + yield run_on_reactor() + defer.returnValue(True) + + def _get_params_recaptcha(self): + return {"public_key": self.hs.config.recaptcha_public_key} + + def _auth_dict_for_flows(self, flows, session): + public_flows = [] + for f in flows: + public_flows.append(f) + + get_params = { + LoginType.RECAPTCHA: self._get_params_recaptcha, + } + + params = {} + + for f in public_flows: + for stage in f: + if stage in get_params and stage not in params: + params[stage] = get_params[stage]() + + return { + "session": session['id'], + "flows": [{"stages": f} for f in public_flows], + "params": params + } + + def _get_session_info(self, session_id): + if session_id not in self.sessions: + session_id = None + + if not session_id: + # create a new session + while session_id is None or session_id in self.sessions: + session_id = stringutils.random_string(24) + self.sessions[session_id] = { + "id": session_id, + } + + return self.sessions[session_id] + + def _save_session(self, session): + # TODO: Persistent storage + logger.debug("Saving session %s", session) + self.sessions[session["id"]] = session + + def _remove_session(self, session): + logger.debug("Removing session %s", session) + del self.sessions[session["id"]] diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py new file mode 100644 index 0000000000..ad8246b58c --- /dev/null +++ b/synapse/handlers/identity.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for interacting with Identity Servers""" +from twisted.internet import defer + +from synapse.api.errors import ( + CodeMessageException +) +from ._base import BaseHandler +from synapse.http.client import SimpleHttpClient +from synapse.util.async import run_on_reactor + +import json +import logging + +logger = logging.getLogger(__name__) + + +class IdentityHandler(BaseHandler): + + def __init__(self, hs): + super(IdentityHandler, self).__init__(hs) + + @defer.inlineCallbacks + def threepid_from_creds(self, creds): + yield run_on_reactor() + + # TODO: get this from the homeserver rather than creating a new one for + # each request + http_client = SimpleHttpClient(self.hs) + # XXX: make this configurable! + # trustedIdServers = ['matrix.org', 'localhost:8090'] + trustedIdServers = ['matrix.org'] + if not creds['id_server'] in trustedIdServers: + logger.warn('%s is not a trusted ID server: rejecting 3pid ' + + 'credentials', creds['id_server']) + defer.returnValue(None) + + data = {} + try: + data = yield http_client.get_json( + "https://%s%s" % ( + creds['id_server'], + "/_matrix/identity/api/v1/3pid/getValidated3pid" + ), + {'sid': creds['sid'], 'client_secret': creds['client_secret']} + ) + except CodeMessageException as e: + data = json.loads(e.msg) + + if 'medium' in data: + defer.returnValue(data) + defer.returnValue(None) + + @defer.inlineCallbacks + def bind_threepid(self, creds, mxid): + yield run_on_reactor() + logger.debug("binding threepid %r to %s", creds, mxid) + http_client = SimpleHttpClient(self.hs) + data = None + try: + data = yield http_client.post_urlencoded_get_json( + "https://%s%s" % ( + creds['id_server'], "/_matrix/identity/api/v1/3pid/bind" + ), + { + 'sid': creds['sid'], + 'client_secret': creds['client_secret'], + 'mxid': mxid, + } + ) + logger.debug("bound threepid %r to %s", creds, mxid) + except CodeMessageException as e: + data = json.loads(e.msg) + defer.returnValue(data) diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py index 7447800460..f7f3698340 100644 --- a/synapse/handlers/login.py +++ b/synapse/handlers/login.py @@ -16,13 +16,9 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.api.errors import LoginError, Codes, CodeMessageException -from synapse.http.client import SimpleHttpClient -from synapse.util.emailutils import EmailException -import synapse.util.emailutils as emailutils +from synapse.api.errors import LoginError, Codes import bcrypt -import json import logging logger = logging.getLogger(__name__) @@ -69,48 +65,19 @@ class LoginHandler(BaseHandler): raise LoginError(403, "", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks - def reset_password(self, user_id, email): - is_valid = yield self._check_valid_association(user_id, email) - logger.info("reset_password user=%s email=%s valid=%s", user_id, email, - is_valid) - if is_valid: - try: - # send an email out - emailutils.send_email( - smtp_server=self.hs.config.email_smtp_server, - from_addr=self.hs.config.email_from_address, - to_addr=email, - subject="Password Reset", - body="TODO." - ) - except EmailException as e: - logger.exception(e) + def set_password(self, user_id, newpassword, token_id=None): + password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) - @defer.inlineCallbacks - def _check_valid_association(self, user_id, email): - identity = yield self._query_email(email) - if identity and "mxid" in identity: - if identity["mxid"] == user_id: - defer.returnValue(True) - return - defer.returnValue(False) + yield self.store.user_set_password_hash(user_id, password_hash) + yield self.store.user_delete_access_tokens_apart_from(user_id, token_id) + yield self.hs.get_pusherpool().remove_pushers_by_user_access_token( + user_id, token_id + ) + yield self.store.flush_user(user_id) @defer.inlineCallbacks - def _query_email(self, email): - http_client = SimpleHttpClient(self.hs) - try: - data = yield http_client.get_json( - # TODO FIXME This should be configurable. - # XXX: ID servers need to use HTTPS - "http://%s%s" % ( - "matrix.org:8090", "/_matrix/identity/api/v1/lookup" - ), - { - 'medium': 'email', - 'address': email - } - ) - defer.returnValue(data) - except CodeMessageException as e: - data = json.loads(e.msg) - defer.returnValue(data) + def add_threepid(self, user_id, medium, address, validated_at): + yield self.store.user_add_threepid( + user_id, medium, address, validated_at, + self.hs.get_clock().time_msec() + ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c25e321099..7b68585a17 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -18,18 +18,15 @@ from twisted.internet import defer from synapse.types import UserID from synapse.api.errors import ( - AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError, - CodeMessageException + AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) from ._base import BaseHandler import synapse.util.stringutils as stringutils from synapse.util.async import run_on_reactor -from synapse.http.client import SimpleHttpClient from synapse.http.client import CaptchaServerHttpClient import base64 import bcrypt -import json import logging import urllib @@ -45,6 +42,30 @@ class RegistrationHandler(BaseHandler): self.distributor.declare("registered_user") @defer.inlineCallbacks + def check_username(self, localpart): + yield run_on_reactor() + + if urllib.quote(localpart) != localpart: + raise SynapseError( + 400, + "User ID must only contain characters which do not" + " require URL encoding." + ) + + user = UserID(localpart, self.hs.hostname) + user_id = user.to_string() + + yield self.check_user_id_is_valid(user_id) + + u = yield self.store.get_user_by_id(user_id) + if u: + raise SynapseError( + 400, + "User ID already taken.", + errcode=Codes.USER_IN_USE, + ) + + @defer.inlineCallbacks def register(self, localpart=None, password=None): """Registers a new client on the server. @@ -64,18 +85,11 @@ class RegistrationHandler(BaseHandler): password_hash = bcrypt.hashpw(password, bcrypt.gensalt()) if localpart: - if localpart and urllib.quote(localpart) != localpart: - raise SynapseError( - 400, - "User ID must only contain characters which do not" - " require URL encoding." - ) + yield self.check_username(localpart) user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - yield self.check_user_id_is_valid(user_id) - token = self._generate_token(user_id) yield self.store.register( user_id=user_id, @@ -157,7 +171,11 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def check_recaptcha(self, ip, private_key, challenge, response): - """Checks a recaptcha is correct.""" + """ + Checks a recaptcha is correct. + + Used only by c/s api v1 + """ captcha_response = yield self._validate_captcha( ip, @@ -176,13 +194,18 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def register_email(self, threepidCreds): - """Registers emails with an identity server.""" + """ + Registers emails with an identity server. + + Used only by c/s api v1 + """ for c in threepidCreds: logger.info("validating theeepidcred sid %s on id server %s", c['sid'], c['idServer']) try: - threepid = yield self._threepid_from_creds(c) + identity_handler = self.hs.get_handlers().identity_handler + threepid = yield identity_handler.threepid_from_creds(c) except: logger.exception("Couldn't validate 3pid") raise RegistrationError(400, "Couldn't validate 3pid") @@ -194,12 +217,16 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def bind_emails(self, user_id, threepidCreds): - """Links emails with a user ID and informs an identity server.""" + """Links emails with a user ID and informs an identity server. + + Used only by c/s api v1 + """ # Now we have a matrix ID, bind it to the threepids we were given for c in threepidCreds: + identity_handler = self.hs.get_handlers().identity_handler # XXX: This should be a deferred list, shouldn't it? - yield self._bind_threepid(c, user_id) + yield identity_handler.bind_threepid(c, user_id) @defer.inlineCallbacks def check_user_id_is_valid(self, user_id): @@ -227,61 +254,11 @@ class RegistrationHandler(BaseHandler): return "-" + stringutils.random_string(18) @defer.inlineCallbacks - def _threepid_from_creds(self, creds): - # TODO: get this from the homeserver rather than creating a new one for - # each request - http_client = SimpleHttpClient(self.hs) - # XXX: make this configurable! - trustedIdServers = ['matrix.org:8090', 'matrix.org'] - if not creds['idServer'] in trustedIdServers: - logger.warn('%s is not a trusted ID server: rejecting 3pid ' + - 'credentials', creds['idServer']) - defer.returnValue(None) - - data = {} - try: - data = yield http_client.get_json( - # XXX: This should be HTTPS - "http://%s%s" % ( - creds['idServer'], - "/_matrix/identity/api/v1/3pid/getValidated3pid" - ), - {'sid': creds['sid'], 'clientSecret': creds['clientSecret']} - ) - except CodeMessageException as e: - data = json.loads(e.msg) - - if 'medium' in data: - defer.returnValue(data) - defer.returnValue(None) - - @defer.inlineCallbacks - def _bind_threepid(self, creds, mxid): - yield - logger.debug("binding threepid") - http_client = SimpleHttpClient(self.hs) - data = None - try: - data = yield http_client.post_urlencoded_get_json( - # XXX: Change when ID servers are all HTTPS - "http://%s%s" % ( - creds['idServer'], "/_matrix/identity/api/v1/3pid/bind" - ), - { - 'sid': creds['sid'], - 'clientSecret': creds['clientSecret'], - 'mxid': mxid, - } - ) - logger.debug("bound threepid") - except CodeMessageException as e: - data = json.loads(e.msg) - defer.returnValue(data) - - @defer.inlineCallbacks def _validate_captcha(self, ip_addr, private_key, challenge, response): """Validates the captcha provided. + Used only by c/s api v1 + Returns: dict: Containing 'valid'(bool) and 'error_url'(str) if invalid. @@ -299,6 +276,9 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def _submit_captcha(self, ip_addr, private_key, challenge, response): + """ + Used only by c/s api v1 + """ # TODO: get this from the homeserver rather than creating a new one for # each request client = CaptchaServerHttpClient(self.hs) diff --git a/synapse/http/client.py b/synapse/http/client.py index 2ae1c4d3a4..e8a5dedab4 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -200,6 +200,8 @@ class CaptchaServerHttpClient(SimpleHttpClient): """ Separate HTTP client for talking to google's captcha servers Only slightly special because accepts partial download responses + + used only by c/s api v1 """ @defer.inlineCallbacks diff --git a/synapse/http/server.py b/synapse/http/server.py index b3706889ab..05636e683b 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -131,10 +131,10 @@ class HttpServer(object): """ def register_path(self, method, path_pattern, callback): - """ Register a callback that get's fired if we receive a http request + """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. - If the regex contains groups these get's passed to the calback via + If the regex contains groups these gets passed to the calback via an unpacked tuple. Args: @@ -153,6 +153,13 @@ class JsonResource(HttpServer, resource.Resource): Resources. Register callbacks via register_path() + + Callbacks can return a tuple of status code and a dict in which case the + the dict will automatically be sent to the client as a JSON object. + + The JsonResource is primarily intended for returning JSON, but callbacks + may send something other than JSON, they may do so by using the methods + on the request object and instead returning None. """ isLeaf = True @@ -185,9 +192,8 @@ class JsonResource(HttpServer, resource.Resource): interface=self.hs.config.bind_host ) - # Gets called by twisted def render(self, request): - """ This get's called by twisted every time someone sends us a request. + """ This gets called by twisted every time someone sends us a request. """ self._async_render(request) return server.NOT_DONE_YET @@ -195,7 +201,7 @@ class JsonResource(HttpServer, resource.Resource): @request_handler @defer.inlineCallbacks def _async_render(self, request): - """ This get's called by twisted every time someone sends us a request. + """ This gets called from render() every time someone sends us a request. This checks if anyone has registered a callback for that method and path. """ @@ -227,9 +233,11 @@ class JsonResource(HttpServer, resource.Resource): urllib.unquote(u).decode("UTF-8") for u in m.groups() ] - code, response = yield callback(request, *args) + callback_return = yield callback(request, *args) + if callback_return is not None: + code, response = callback_return + self._send_response(request, code, response) - self._send_response(request, code, response) response_timer.inc_by( self.clock.time_msec() - start, request.method, servlet_classname ) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 0727f772a5..5575c847f9 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -253,7 +253,8 @@ class Pusher(object): self.user_name, config, timeout=0) self.last_token = chunk['end'] self.store.update_pusher_last_token( - self.app_id, self.pushkey, self.last_token) + self.app_id, self.pushkey, self.user_name, self.last_token + ) logger.info("Pusher %s for user %s starting from token %s", self.pushkey, self.user_name, self.last_token) @@ -314,7 +315,7 @@ class Pusher(object): pk ) yield self.hs.get_pusherpool().remove_pusher( - self.app_id, pk + self.app_id, pk, self.user_name ) if not self.alive: @@ -326,6 +327,7 @@ class Pusher(object): self.store.update_pusher_last_token_and_success( self.app_id, self.pushkey, + self.user_name, self.last_token, self.clock.time_msec() ) @@ -334,6 +336,7 @@ class Pusher(object): self.store.update_pusher_failing_since( self.app_id, self.pushkey, + self.user_name, self.failing_since) else: if not self.failing_since: @@ -341,6 +344,7 @@ class Pusher(object): self.store.update_pusher_failing_since( self.app_id, self.pushkey, + self.user_name, self.failing_since ) @@ -358,6 +362,7 @@ class Pusher(object): self.store.update_pusher_last_token( self.app_id, self.pushkey, + self.user_name, self.last_token ) @@ -365,6 +370,7 @@ class Pusher(object): self.store.update_pusher_failing_since( self.app_id, self.pushkey, + self.user_name, self.failing_since ) else: diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 90babd7224..041ce8f22a 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -57,7 +57,7 @@ class PusherPool: self._start_pushers(pushers) @defer.inlineCallbacks - def add_pusher(self, user_name, profile_tag, kind, app_id, + def add_pusher(self, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, lang, data): # we try to create the pusher just to validate the config: it # will then get pulled out of the database, @@ -71,7 +71,7 @@ class PusherPool: "app_display_name": app_display_name, "device_display_name": device_display_name, "pushkey": pushkey, - "pushkey_ts": self.hs.get_clock().time_msec(), + "ts": self.hs.get_clock().time_msec(), "lang": lang, "data": data, "last_token": None, @@ -79,17 +79,50 @@ class PusherPool: "failing_since": None }) yield self._add_pusher_to_store( - user_name, profile_tag, kind, app_id, + user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, lang, data ) @defer.inlineCallbacks - def _add_pusher_to_store(self, user_name, profile_tag, kind, app_id, - app_display_name, device_display_name, + def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, + not_user_id): + to_remove = yield self.store.get_pushers_by_app_id_and_pushkey( + app_id, pushkey + ) + for p in to_remove: + if p['user_name'] != not_user_id: + logger.info( + "Removing pusher for app id %s, pushkey %s, user %s", + app_id, pushkey, p['user_name'] + ) + self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + + @defer.inlineCallbacks + def remove_pushers_by_user_access_token(self, user_id, not_access_token_id): + all = yield self.store.get_all_pushers() + logger.info( + "Removing all pushers for user %s except access token %s", + user_id, not_access_token_id + ) + for p in all: + if ( + p['user_name'] == user_id and + p['access_token'] != not_access_token_id + ): + logger.info( + "Removing pusher for app id %s, pushkey %s, user %s", + p['app_id'], p['pushkey'], p['user_name'] + ) + self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + + @defer.inlineCallbacks + def _add_pusher_to_store(self, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, pushkey, lang, data): yield self.store.add_pusher( user_name=user_name, + access_token=access_token, profile_tag=profile_tag, kind=kind, app_id=app_id, @@ -100,7 +133,7 @@ class PusherPool: lang=lang, data=encode_canonical_json(data).decode("UTF-8"), ) - self._refresh_pusher((app_id, pushkey)) + self._refresh_pusher(app_id, pushkey, user_name) def _create_pusher(self, pusherdict): if pusherdict['kind'] == 'http': @@ -112,7 +145,7 @@ class PusherPool: app_display_name=pusherdict['app_display_name'], device_display_name=pusherdict['device_display_name'], pushkey=pusherdict['pushkey'], - pushkey_ts=pusherdict['pushkey_ts'], + pushkey_ts=pusherdict['ts'], data=pusherdict['data'], last_token=pusherdict['last_token'], last_success=pusherdict['last_success'], @@ -125,30 +158,42 @@ class PusherPool: ) @defer.inlineCallbacks - def _refresh_pusher(self, app_id_pushkey): - p = yield self.store.get_pushers_by_app_id_and_pushkey( - app_id_pushkey + def _refresh_pusher(self, app_id, pushkey, user_name): + resultlist = yield self.store.get_pushers_by_app_id_and_pushkey( + app_id, pushkey ) - p['data'] = json.loads(p['data']) + p = None + for r in resultlist: + if r['user_name'] == user_name: + p = r + + if p: + p['data'] = json.loads(p['data']) - self._start_pushers([p]) + self._start_pushers([p]) def _start_pushers(self, pushers): logger.info("Starting %d pushers", len(pushers)) for pusherdict in pushers: p = self._create_pusher(pusherdict) if p: - fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey']) + fullid = "%s:%s:%s" % ( + pusherdict['app_id'], + pusherdict['pushkey'], + pusherdict['user_name'] + ) if fullid in self.pushers: self.pushers[fullid].stop() self.pushers[fullid] = p p.start() @defer.inlineCallbacks - def remove_pusher(self, app_id, pushkey): - fullid = "%s:%s" % (app_id, pushkey) + def remove_pusher(self, app_id, pushkey, user_name): + fullid = "%s:%s:%s" % (app_id, pushkey, user_name) if fullid in self.pushers: logger.info("Stopping pusher %s", fullid) self.pushers[fullid].stop() del self.pushers[fullid] - yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey) + yield self.store.delete_pusher_by_app_id_pushkey_user_name( + app_id, pushkey, user_name + ) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 72332bdb10..504a5e432f 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -48,5 +48,5 @@ class ClientV1RestServlet(RestServlet): self.hs = hs self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() - self.auth = hs.get_auth() + self.auth = hs.get_v1auth() self.txns = HttpTransactionStore() diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 6045e86f34..c83287c028 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - user, _ = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -37,7 +37,7 @@ class PusherRestServlet(ClientV1RestServlet): and 'kind' in content and content['kind'] is None): yield pusher_pool.remove_pusher( - content['app_id'], content['pushkey'] + content['app_id'], content['pushkey'], user_name=user.to_string() ) defer.returnValue((200, {})) @@ -51,9 +51,21 @@ class PusherRestServlet(ClientV1RestServlet): raise SynapseError(400, "Missing parameters: "+','.join(missing), errcode=Codes.MISSING_PARAM) + append = False + if 'append' in content: + append = content['append'] + + if not append: + yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( + app_id=content['app_id'], + pushkey=content['pushkey'], + not_user_id=user.to_string() + ) + try: yield pusher_pool.add_pusher( user_name=user.to_string(), + access_token=client.token_id, profile_tag=content['profile_tag'], kind=content['kind'], app_id=content['app_id'], diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index bca65f2a6a..28d95b2729 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -15,7 +15,10 @@ from . import ( sync, - filter + filter, + account, + register, + auth ) from synapse.http.server import JsonResource @@ -32,3 +35,6 @@ class ClientV2AlphaRestResource(JsonResource): def register_servlets(client_resource, hs): sync.register_servlets(hs, client_resource) filter.register_servlets(hs, client_resource) + account.register_servlets(hs, client_resource) + register.register_servlets(hs, client_resource) + auth.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 22dc5cb862..4540e8dcf7 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -17,9 +17,11 @@ """ from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +from synapse.api.errors import SynapseError import re import logging +import simplejson logger = logging.getLogger(__name__) @@ -36,3 +38,23 @@ def client_v2_pattern(path_regex): SRE_Pattern """ return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) + + +def parse_request_allow_empty(request): + content = request.content.read() + if content is None or content == '': + return None + try: + return simplejson.loads(content) + except simplejson.JSONDecodeError: + raise SynapseError(400, "Content not JSON.") + + +def parse_json_dict_from_request(request): + try: + content = simplejson.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.") + return content + except simplejson.JSONDecodeError: + raise SynapseError(400, "Content not JSON.") diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py new file mode 100644 index 0000000000..3e522ad39b --- /dev/null +++ b/synapse/rest/client/v2_alpha/account.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import LoginType +from synapse.api.errors import LoginError, SynapseError, Codes +from synapse.http.servlet import RestServlet +from synapse.util.async import run_on_reactor + +from ._base import client_v2_pattern, parse_json_dict_from_request + +import logging + + +logger = logging.getLogger(__name__) + + +class PasswordRestServlet(RestServlet): + PATTERN = client_v2_pattern("/account/password") + + def __init__(self, hs): + super(PasswordRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_handlers().auth_handler + self.login_handler = hs.get_handlers().login_handler + + @defer.inlineCallbacks + def on_POST(self, request): + yield run_on_reactor() + + body = parse_json_dict_from_request(request) + + authed, result, params = yield self.auth_handler.check_auth([ + [LoginType.PASSWORD], + [LoginType.EMAIL_IDENTITY] + ], body) + + if not authed: + defer.returnValue((401, result)) + + user_id = None + + if LoginType.PASSWORD in result: + # if using password, they should also be logged in + auth_user, client = yield self.auth.get_user_by_req(request) + if auth_user.to_string() != result[LoginType.PASSWORD]: + raise LoginError(400, "", Codes.UNKNOWN) + user_id = auth_user.to_string() + elif LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + if 'medium' not in threepid or 'address' not in threepid: + raise SynapseError(500, "Malformed threepid") + # if using email, we must know about the email they're authing with! + threepid_user = yield self.hs.get_datastore().get_user_by_threepid( + threepid['medium'], threepid['address'] + ) + if not threepid_user: + raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) + user_id = threepid_user + else: + logger.error("Auth succeeded but no known type!", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) + + if 'new_password' not in params: + raise SynapseError(400, "", Codes.MISSING_PARAM) + new_password = params['new_password'] + + yield self.login_handler.set_password( + user_id, new_password, None + ) + + defer.returnValue((200, {})) + + def on_OPTIONS(self, _): + return 200, {} + + +class ThreepidRestServlet(RestServlet): + PATTERN = client_v2_pattern("/account/3pid") + + def __init__(self, hs): + super(ThreepidRestServlet, self).__init__() + self.hs = hs + self.login_handler = hs.get_handlers().login_handler + self.identity_handler = hs.get_handlers().identity_handler + self.auth = hs.get_auth() + + @defer.inlineCallbacks + def on_GET(self, request): + yield run_on_reactor() + + auth_user, _ = yield self.auth.get_user_by_req(request) + + threepids = yield self.hs.get_datastore().user_get_threepids( + auth_user.to_string() + ) + + defer.returnValue((200, {'threepids': threepids})) + + @defer.inlineCallbacks + def on_POST(self, request): + yield run_on_reactor() + + body = parse_json_dict_from_request(request) + + if 'threePidCreds' not in body: + raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) + threePidCreds = body['threePidCreds'] + + auth_user, client = yield self.auth.get_user_by_req(request) + + threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) + + if not threepid: + raise SynapseError( + 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED + ) + + for reqd in ['medium', 'address', 'validated_at']: + if reqd not in threepid: + logger.warn("Couldn't add 3pid: invalid response from ID sevrer") + raise SynapseError(500, "Invalid response from ID Server") + + yield self.login_handler.add_threepid( + auth_user.to_string(), + threepid['medium'], + threepid['address'], + threepid['validated_at'], + ) + + if 'bind' in body and body['bind']: + logger.debug( + "Binding emails %s to %s", + threepid, auth_user.to_string() + ) + yield self.identity_handler.bind_threepid( + threePidCreds, auth_user.to_string() + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + PasswordRestServlet(hs).register(http_server) + ThreepidRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py new file mode 100644 index 0000000000..4c726f05f5 --- /dev/null +++ b/synapse/rest/client/v2_alpha/auth.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import LoginType +from synapse.api.errors import SynapseError +from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +from synapse.http.servlet import RestServlet + +from ._base import client_v2_pattern + +import logging + + +logger = logging.getLogger(__name__) + +RECAPTCHA_TEMPLATE = """ +<html> +<head> +<title>Authentication</title> +<meta name='viewport' content='width=device-width, initial-scale=1, + user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> +<script src="https://www.google.com/recaptcha/api.js" + async defer></script> +<script src="//code.jquery.com/jquery-1.11.2.min.js"></script> +<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> +<script> +function captchaDone() { + $('#registrationForm').submit(); +} +</script> +</head> +<body> +<form id="registrationForm" method="post" action="%(myurl)s"> + <div> + <p> + Hello! We need to prevent computer programs and other automated + things from creating accounts on this server. + </p> + <p> + Please verify that you're not a robot. + </p> + <input type="hidden" name="session" value="%(session)s" /> + <div class="g-recaptcha" + data-sitekey="%(sitekey)s" + data-callback="captchaDone"> + </div> + <noscript> + <input type="submit" value="All Done" /> + </noscript> + </div> + </div> +</form> +</body> +</html> +""" + +SUCCESS_TEMPLATE = """ +<html> +<head> +<title>Success!</title> +<meta name='viewport' content='width=device-width, initial-scale=1, + user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> +<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> +<script> +if (window.onAuthDone != undefined) { + window.onAuthDone(); +} +</script> +</head> +<body> + <div> + <p>Thank you</p> + <p>You may now close this window and return to the application</p> + </div> +</body> +</html> +""" + + +class AuthRestServlet(RestServlet): + """ + Handles Client / Server API authentication in any situations where it + cannot be handled in the normal flow (with requests to the same endpoint). + Current use is for web fallback auth. + """ + PATTERN = client_v2_pattern("/auth/(?P<stagetype>[\w\.]*)/fallback/web") + + def __init__(self, hs): + super(AuthRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_handlers().auth_handler + self.registration_handler = hs.get_handlers().registration_handler + + @defer.inlineCallbacks + def on_GET(self, request, stagetype): + yield + if stagetype == LoginType.RECAPTCHA: + if ('session' not in request.args or + len(request.args['session']) == 0): + raise SynapseError(400, "No session supplied") + + session = request.args["session"][0] + + html = RECAPTCHA_TEMPLATE % { + 'session': session, + 'myurl': "%s/auth/%s/fallback/web" % ( + CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA + ), + 'sitekey': self.hs.config.recaptcha_public_key, + } + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Server", self.hs.version_string) + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + request.finish() + defer.returnValue(None) + else: + raise SynapseError(404, "Unknown auth stage type") + + @defer.inlineCallbacks + def on_POST(self, request, stagetype): + yield + if stagetype == "m.login.recaptcha": + if ('g-recaptcha-response' not in request.args or + len(request.args['g-recaptcha-response'])) == 0: + raise SynapseError(400, "No captcha response supplied") + if ('session' not in request.args or + len(request.args['session'])) == 0: + raise SynapseError(400, "No session supplied") + + session = request.args['session'][0] + + authdict = { + 'response': request.args['g-recaptcha-response'][0], + 'session': session, + } + + success = yield self.auth_handler.add_oob_auth( + LoginType.RECAPTCHA, + authdict, + self.hs.get_ip_from_request(request) + ) + + if success: + html = SUCCESS_TEMPLATE + else: + html = RECAPTCHA_TEMPLATE % { + 'session': session, + 'myurl': "%s/auth/%s/fallback/web" % ( + CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA + ), + 'sitekey': self.hs.config.recaptcha_public_key, + } + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Server", self.hs.version_string) + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + request.finish() + + defer.returnValue(None) + else: + raise SynapseError(404, "Unknown auth stage type") + + def on_OPTIONS(self, _): + return 200, {} + + +def register_servlets(hs, http_server): + AuthRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py new file mode 100644 index 0000000000..3640fb4a29 --- /dev/null +++ b/synapse/rest/client/v2_alpha/register.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import LoginType +from synapse.api.errors import SynapseError, Codes +from synapse.http.servlet import RestServlet + +from ._base import client_v2_pattern, parse_request_allow_empty + +import logging +import hmac +from hashlib import sha1 +from synapse.util.async import run_on_reactor + + +# We ought to be using hmac.compare_digest() but on older pythons it doesn't +# exist. It's a _really minor_ security flaw to use plain string comparison +# because the timing attack is so obscured by all the other code here it's +# unlikely to make much difference +if hasattr(hmac, "compare_digest"): + compare_digest = hmac.compare_digest +else: + compare_digest = lambda a, b: a == b + + +logger = logging.getLogger(__name__) + + +class RegisterRestServlet(RestServlet): + PATTERN = client_v2_pattern("/register") + + def __init__(self, hs): + super(RegisterRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_handlers().auth_handler + self.registration_handler = hs.get_handlers().registration_handler + self.identity_handler = hs.get_handlers().identity_handler + self.login_handler = hs.get_handlers().login_handler + + @defer.inlineCallbacks + def on_POST(self, request): + yield run_on_reactor() + + body = parse_request_allow_empty(request) + if 'password' not in body: + raise SynapseError(400, "", Codes.MISSING_PARAM) + + if 'username' in body: + desired_username = body['username'] + yield self.registration_handler.check_username(desired_username) + + is_using_shared_secret = False + is_application_server = False + + service = None + if 'access_token' in request.args: + service = yield self.auth.get_appservice_by_req(request) + + if self.hs.config.enable_registration_captcha: + flows = [ + [LoginType.RECAPTCHA], + [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA] + ] + else: + flows = [ + [LoginType.DUMMY], + [LoginType.EMAIL_IDENTITY] + ] + + if service: + is_application_server = True + elif 'mac' in body: + # Check registration-specific shared secret auth + if 'username' not in body: + raise SynapseError(400, "", Codes.MISSING_PARAM) + self._check_shared_secret_auth( + body['username'], body['mac'] + ) + is_using_shared_secret = True + else: + authed, result, params = yield self.auth_handler.check_auth( + flows, body, self.hs.get_ip_from_request(request) + ) + + if not authed: + defer.returnValue((401, result)) + + can_register = ( + not self.hs.config.disable_registration + or is_application_server + or is_using_shared_secret + ) + if not can_register: + raise SynapseError(403, "Registration has been disabled") + + if 'password' not in params: + raise SynapseError(400, "", Codes.MISSING_PARAM) + desired_username = params['username'] if 'username' in params else None + new_password = params['password'] + + (user_id, token) = yield self.registration_handler.register( + localpart=desired_username, + password=new_password + ) + + if LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + + for reqd in ['medium', 'address', 'validated_at']: + if reqd not in threepid: + logger.info("Can't add incomplete 3pid") + else: + yield self.login_handler.add_threepid( + user_id, + threepid['medium'], + threepid['address'], + threepid['validated_at'], + ) + + if 'bind_email' in params and params['bind_email']: + logger.info("bind_email specified: binding") + + emailThreepid = result[LoginType.EMAIL_IDENTITY] + threepid_creds = emailThreepid['threepid_creds'] + logger.debug("Binding emails %s to %s" % ( + emailThreepid, user_id + )) + yield self.identity_handler.bind_threepid(threepid_creds, user_id) + else: + logger.info("bind_email not specified: not binding email") + + result = { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + def on_OPTIONS(self, _): + return 200, {} + + def _check_shared_secret_auth(self, username, mac): + if not self.hs.config.registration_shared_secret: + raise SynapseError(400, "Shared secret registration is not enabled") + + user = username.encode("utf-8") + + # str() because otherwise hmac complains that 'unicode' does not + # have the buffer interface + got_mac = str(mac) + + want_mac = hmac.new( + key=self.hs.config.registration_shared_secret, + msg=user, + digestmod=sha1, + ).hexdigest() + + if compare_digest(want_mac, got_mac): + return True + else: + raise SynapseError( + 403, "HMAC incorrect", + ) + + +def register_servlets(hs, http_server): + RegisterRestServlet(hs).register(http_server) diff --git a/synapse/server.py b/synapse/server.py index 0bd87bdd77..af87dab12c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -65,6 +65,7 @@ class BaseHomeServer(object): 'replication_layer', 'datastore', 'handlers', + 'v1auth', 'auth', 'rest_servlet_factory', 'state_handler', @@ -181,6 +182,15 @@ class HomeServer(BaseHomeServer): def build_auth(self): return Auth(self) + def build_v1auth(self): + orf = Auth(self) + # Matrix spec makes no reference to what HTTP status code is returned, + # but the V1 API uses 403 where it means 401, and the webclient + # relies on this behaviour, so V1 gets its own copy of the auth + # with backwards compat behaviour. + orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403 + return orf + def build_state_handler(self): return StateHandler(self) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 000502b4ff..1c657beddb 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -27,67 +27,40 @@ logger = logging.getLogger(__name__) class PusherStore(SQLBaseStore): @defer.inlineCallbacks - def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey): + def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): + cols = ",".join(PushersTable.fields) sql = ( - "SELECT id, user_name, kind, profile_tag, app_id," - "app_display_name, device_display_name, pushkey, ts, data, " - "last_token, last_success, failing_since " - "FROM pushers " + "SELECT "+cols+" FROM pushers " "WHERE app_id = ? AND pushkey = ?" ) rows = yield self._execute( "get_pushers_by_app_id_and_pushkey", None, sql, - app_id_and_pushkey[0], app_id_and_pushkey[1] + app_id, pushkey ) ret = [ { - "id": r[0], - "user_name": r[1], - "kind": r[2], - "profile_tag": r[3], - "app_id": r[4], - "app_display_name": r[5], - "device_display_name": r[6], - "pushkey": r[7], - "pushkey_ts": r[8], - "data": r[9], - "last_token": r[10], - "last_success": r[11], - "failing_since": r[12] + k: r[i] for i, k in enumerate(PushersTable.fields) } for r in rows ] + print ret - defer.returnValue(ret[0]) + defer.returnValue(ret) @defer.inlineCallbacks def get_all_pushers(self): + cols = ",".join(PushersTable.fields) sql = ( - "SELECT id, user_name, kind, profile_tag, app_id," - "app_display_name, device_display_name, pushkey, ts, data, " - "last_token, last_success, failing_since " - "FROM pushers" + "SELECT "+cols+" FROM pushers" ) rows = yield self._execute("get_all_pushers", None, sql) ret = [ { - "id": r[0], - "user_name": r[1], - "kind": r[2], - "profile_tag": r[3], - "app_id": r[4], - "app_display_name": r[5], - "device_display_name": r[6], - "pushkey": r[7], - "pushkey_ts": r[8], - "data": r[9], - "last_token": r[10], - "last_success": r[11], - "failing_since": r[12] + k: r[i] for i, k in enumerate(PushersTable.fields) } for r in rows ] @@ -95,7 +68,7 @@ class PusherStore(SQLBaseStore): defer.returnValue(ret) @defer.inlineCallbacks - def add_pusher(self, user_name, profile_tag, kind, app_id, + def add_pusher(self, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, lang, data): try: @@ -104,9 +77,10 @@ class PusherStore(SQLBaseStore): dict( app_id=app_id, pushkey=pushkey, + user_name=user_name, ), dict( - user_name=user_name, + access_token=access_token, kind=kind, profile_tag=profile_tag, app_display_name=app_display_name, @@ -122,37 +96,38 @@ class PusherStore(SQLBaseStore): raise StoreError(500, "Problem creating pusher.") @defer.inlineCallbacks - def delete_pusher_by_app_id_pushkey(self, app_id, pushkey): + def delete_pusher_by_app_id_pushkey_user_name(self, app_id, pushkey, user_name): yield self._simple_delete_one( PushersTable.table_name, - {"app_id": app_id, "pushkey": pushkey}, - desc="delete_pusher_by_app_id_pushkey", + {"app_id": app_id, "pushkey": pushkey, 'user_name': user_name}, + desc="delete_pusher_by_app_id_pushkey_user_name", ) @defer.inlineCallbacks - def update_pusher_last_token(self, app_id, pushkey, last_token): + def update_pusher_last_token(self, app_id, pushkey, user_name, last_token): yield self._simple_update_one( PushersTable.table_name, - {'app_id': app_id, 'pushkey': pushkey}, + {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name}, {'last_token': last_token}, desc="update_pusher_last_token", ) @defer.inlineCallbacks - def update_pusher_last_token_and_success(self, app_id, pushkey, + def update_pusher_last_token_and_success(self, app_id, pushkey, user_name, last_token, last_success): yield self._simple_update_one( PushersTable.table_name, - {'app_id': app_id, 'pushkey': pushkey}, + {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name}, {'last_token': last_token, 'last_success': last_success}, desc="update_pusher_last_token_and_success", ) @defer.inlineCallbacks - def update_pusher_failing_since(self, app_id, pushkey, failing_since): + def update_pusher_failing_since(self, app_id, pushkey, user_name, + failing_since): yield self._simple_update_one( PushersTable.table_name, - {'app_id': app_id, 'pushkey': pushkey}, + {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name}, {'failing_since': failing_since}, desc="update_pusher_failing_since", ) @@ -164,13 +139,15 @@ class PushersTable(Table): fields = [ "id", "user_name", + "access_token", "kind", "profile_tag", "app_id", "app_display_name", "device_display_name", "pushkey", - "pushkey_ts", + "ts", + "lang", "data", "last_token", "last_success", diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index f24154f146..f85cbb0d9d 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -89,17 +89,48 @@ class RegistrationStore(SQLBaseStore): "VALUES (?,?)", [txn.lastrowid, token]) def get_user_by_id(self, user_id): - query = ("SELECT users.name, users.password_hash FROM users" + query = ("SELECT users.id, users.name, users.password_hash FROM users" " WHERE users.name = ?") return self._execute( "get_user_by_id", self.cursor_to_dict, query, user_id ) + @defer.inlineCallbacks + def user_set_password_hash(self, user_id, password_hash): + """ + NB. This does *not* evict any cache because the one use for this + removes most of the entries subsequently anyway so it would be + pointless. Use flush_user separately. + """ + yield self._simple_update_one('users', { + 'name': user_id + }, { + 'password_hash': password_hash + }) + + @defer.inlineCallbacks + def user_delete_access_tokens_apart_from(self, user_id, token_id): + rows = yield self.get_user_by_id(user_id) + if len(rows) == 0: + raise Exception("No such user!") + + yield self._execute( + "delete_access_tokens_apart_from", None, + "DELETE FROM access_tokens WHERE user_id = ? AND id != ?", + rows[0]['id'], token_id + ) + + @defer.inlineCallbacks + def flush_user(self, user_id): + rows = yield self._execute( + 'flush_user', None, + "SELECT token FROM access_tokens WHERE user_id = ?", + user_id + ) + for r in rows: + self.get_user_by_token.invalidate(r) + @cached() - # TODO(paul): Currently there's no code to invalidate this cache. That - # means if/when we ever add internal ways to invalidate access tokens or - # change whether a user is a server admin, those will need to invoke - # store.get_user_by_token.invalidate(token) def get_user_by_token(self, token): """Get a user from the given access token. @@ -143,4 +174,40 @@ class RegistrationStore(SQLBaseStore): if rows: return rows[0] - raise StoreError(404, "Token not found.") + return None + + @defer.inlineCallbacks + def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + yield self._simple_upsert("user_threepids", { + "user": user_id, + "medium": medium, + "address": address, + }, { + "validated_at": validated_at, + "added_at": added_at, + }) + + @defer.inlineCallbacks + def user_get_threepids(self, user_id): + ret = yield self._simple_select_list( + "user_threepids", { + "user": user_id + }, + ['medium', 'address', 'validated_at', 'added_at'], + 'user_get_threepids' + ) + defer.returnValue(ret) + + @defer.inlineCallbacks + def get_user_by_threepid(self, medium, address): + ret = yield self._simple_select_one( + "user_threepids", + { + "medium": medium, + "address": address + }, + ['user'], True, 'get_user_by_threepid' + ) + if ret: + defer.returnValue(ret['user']) + defer.returnValue(None) diff --git a/synapse/storage/schema/delta/15/v15.sql b/synapse/storage/schema/delta/15/v15.sql new file mode 100644 index 0000000000..f5b2a08ca4 --- /dev/null +++ b/synapse/storage/schema/delta/15/v15.sql @@ -0,0 +1,25 @@ +-- Drop, copy & recreate pushers table to change unique key +-- Also add access_token column at the same time +CREATE TABLE IF NOT EXISTS pushers2 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + access_token INTEGER DEFAULT NULL, + profile_tag varchar(32) NOT NULL, + kind varchar(8) NOT NULL, + app_id varchar(64) NOT NULL, + app_display_name varchar(64) NOT NULL, + device_display_name varchar(128) NOT NULL, + pushkey blob NOT NULL, + ts BIGINT NOT NULL, + lang varchar(8), + data blob, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + FOREIGN KEY(user_name) REFERENCES users(name), + UNIQUE (app_id, pushkey, user_name) +); +INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since) + SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers; +DROP TABLE pushers; +ALTER TABLE pushers2 RENAME TO pushers; diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index b9c03383a2..8e0c5fa630 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -75,7 +75,7 @@ class PresenceStateTestCase(unittest.TestCase): "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_token = _get_user_by_token room_member_handler = hs.handlers.room_member_handler = Mock( spec=[ @@ -170,7 +170,7 @@ class PresenceListTestCase(unittest.TestCase): ] ) - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_token = _get_user_by_token presence.register_servlets(hs, self.mock_resource) @@ -277,7 +277,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): def _get_user_by_req(req=None): return (UserID.from_string(myid), "") - hs.get_auth().get_user_by_req = _get_user_by_req + hs.get_v1auth().get_user_by_req = _get_user_by_req presence.register_servlets(hs, self.mock_resource) events.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 5cd5767f2e..929e5e5dd4 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -55,7 +55,7 @@ class ProfileTestCase(unittest.TestCase): def _get_user_by_req(request=None): return (UserID.from_string(myid), "") - hs.get_auth().get_user_by_req = _get_user_by_req + hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_handlers().profile_handler = self.mock_handler diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 72fb4576b1..c83348acf9 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -61,7 +61,7 @@ class RoomPermissionsTestCase(RestTestCase): "device_id": None, "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_token = _get_user_by_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -71,7 +71,7 @@ class RoomPermissionsTestCase(RestTestCase): synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - self.auth = hs.get_auth() + self.auth = hs.get_v1auth() # create some rooms under the name rmcreator_id self.uncreated_rmid = "!aa:test" @@ -448,7 +448,7 @@ class RoomsMemberListTestCase(RestTestCase): "device_id": None, "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_token = _get_user_by_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -528,7 +528,7 @@ class RoomsCreateTestCase(RestTestCase): "device_id": None, "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_token = _get_user_by_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -622,7 +622,7 @@ class RoomTopicTestCase(RestTestCase): "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_token = _get_user_by_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -728,7 +728,7 @@ class RoomMemberStateTestCase(RestTestCase): "device_id": None, "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_token = _get_user_by_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -855,7 +855,7 @@ class RoomMessagesTestCase(RestTestCase): "device_id": None, "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_token = _get_user_by_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -952,7 +952,7 @@ class RoomInitialSyncTestCase(RestTestCase): "device_id": None, "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_token = _get_user_by_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 7b3bd87439..7d8b1c2683 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -69,7 +69,7 @@ class RoomTypingTestCase(RestTestCase): "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_token = _get_user_by_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index e0b81f2b57..2f8953f518 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -38,13 +38,12 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_register(self): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) + u = (yield self.store.get_user_by_id(self.user_id))[0] - self.assertEquals( - # TODO(paul): Surely this field should be 'user_id', not 'name' - # Additionally surely it shouldn't come in a 1-element list - [{"name": self.user_id, "password_hash": self.pwhash}], - (yield self.store.get_user_by_id(self.user_id)) - ) + # TODO(paul): Surely this field should be 'user_id', not 'name' + # Additionally surely it shouldn't come in a 1-element list + self.assertEquals(self.user_id, u['name']) + self.assertEquals(self.pwhash, u['password_hash']) self.assertEquals( {"admin": 0, |