From af691e415c3247b912137227a06a68d4c4356586 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 18 Feb 2019 16:49:38 +0000 Subject: Move register_device into handler --- synapse/handlers/register.py | 51 ++++++++++++++++++++++++--- synapse/replication/http/login.py | 17 ++------- synapse/rest/client/v1/login.py | 59 ++++++++++++-------------------- synapse/rest/client/v2_alpha/register.py | 49 ++------------------------ 4 files changed, 74 insertions(+), 102 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8ea557a003..f92ab4d525 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -27,6 +27,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.http.client import CaptchaServerHttpClient +from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ReplicationRegisterServlet from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.util.async_helpers import Linearizer @@ -64,6 +65,11 @@ class RegistrationHandler(BaseHandler): if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) + self._register_device_client = ( + RegisterDeviceReplicationServlet.make_client(hs) + ) + else: + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, @@ -159,7 +165,7 @@ class RegistrationHandler(BaseHandler): yield self.auth.check_auth_blocking(threepid=threepid) password_hash = None if password: - password_hash = yield self.auth_handler().hash(password) + password_hash = yield self._auth_handler.hash(password) if localpart: yield self.check_username(localpart, guest_access_token=guest_access_token) @@ -516,9 +522,6 @@ class RegistrationHandler(BaseHandler): defer.returnValue((user_id, token)) - def auth_handler(self): - return self.hs.get_auth_handler() - @defer.inlineCallbacks def get_or_register_3pid_guest(self, medium, address, inviter_user_id): """Get a guest access token for a 3PID, creating a guest account if @@ -628,3 +631,43 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, ) + + @defer.inlineCallbacks + def register_device(self, user_id, device_id, initial_display_name, + is_guest=False): + """Register a device for a user and generate an access token. + + Args: + user_id (str): full canonical @user:id + device_id (str|None): The device ID to check, or None to generate + a new one. + initial_display_name (str|None): An optional display name for the + device. + is_guest (bool): Whether this is a guest account + + Returns: + defer.Deferred[tuple[str, str]]: Tuple of device ID and access token + """ + + if self.hs.config.worker_app: + r = yield self._register_device_client( + user_id=user_id, + device_id=device_id, + initial_display_name=initial_display_name, + is_guest=is_guest, + ) + defer.returnValue((r["device_id"], r["access_token"])) + else: + device_id = yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + if is_guest: + access_token = self.macaroon_gen.generate_access_token( + user_id, ["guest = true"] + ) + else: + access_token = yield self._auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, + ) + + defer.returnValue((device_id, access_token)) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 797f6aabd1..1590eca317 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -35,9 +35,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): def __init__(self, hs): super(RegisterDeviceReplicationServlet, self).__init__(hs) - self.auth_handler = hs.get_auth_handler() - self.device_handler = hs.get_device_handler() - self.macaroon_gen = hs.get_macaroon_generator() + self.registration_handler = hs.get_handlers().registration_handler @staticmethod def _serialize_payload(user_id, device_id, initial_display_name, is_guest): @@ -62,19 +60,10 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] - device_id = yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name, + device_id, access_token = yield self.registration_handler.register_device( + user_id, device_id, initial_display_name, is_guest, ) - if is_guest: - access_token = self.macaroon_gen.generate_access_token( - user_id, ["guest = true"] - ) - else: - access_token = yield self.auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, - ) - defer.returnValue((200, { "device_id": device_id, "access_token": access_token, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 942e4d3816..4a5775083f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -94,7 +94,7 @@ class LoginRestServlet(ClientV1RestServlet): self.jwt_algorithm = hs.config.jwt_algorithm self.cas_enabled = hs.config.cas_enabled self.auth_handler = self.hs.get_auth_handler() - self.device_handler = self.hs.get_device_handler() + self.registration_handler = hs.get_handlers().registration_handler self.handlers = hs.get_handlers() self._well_known_builder = WellKnownBuilder(hs) @@ -220,11 +220,10 @@ class LoginRestServlet(ClientV1RestServlet): login_submission, ) - device_id = yield self._register_device( - canonical_user_id, login_submission, - ) - access_token = yield auth_handler.get_access_token_for_user_id( - canonical_user_id, device_id, + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = yield self.registration_handler.register_device( + canonical_user_id, device_id, initial_display_name, ) result = { @@ -246,10 +245,13 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) - device_id = yield self._register_device(user_id, login_submission) - access_token = yield auth_handler.get_access_token_for_user_id( - user_id, device_id, + + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = yield self.registration_handler.register_device( + user_id, device_id, initial_display_name, ) + result = { "user_id": user_id, # may have changed "access_token": access_token, @@ -286,11 +288,10 @@ class LoginRestServlet(ClientV1RestServlet): auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: - device_id = yield self._register_device( - registered_user_id, login_submission - ) - access_token = yield auth_handler.get_access_token_for_user_id( - registered_user_id, device_id, + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = yield self.registration_handler.register_device( + registered_user_id, device_id, initial_display_name, ) result = { @@ -299,12 +300,16 @@ class LoginRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, } else: - # TODO: we should probably check that the register isn't going - # to fonx/change our user_id before registering the device - device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) + + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = yield self.registration_handler.register_device( + registered_user_id, device_id, initial_display_name, + ) + result = { "user_id": user_id, # may have changed "access_token": access_token, @@ -313,26 +318,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue(result) - def _register_device(self, user_id, login_submission): - """Register a device for a user. - - This is called after the user's credentials have been validated, but - before the access token has been issued. - - Args: - (str) user_id: full canonical @user:id - (object) login_submission: dictionary supplied to /login call, from - which we pull device_id and initial_device_name - Returns: - defer.Deferred: (str) device_id - """ - device_id = login_submission.get("device_id") - initial_display_name = login_submission.get( - "initial_device_display_name") - return self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) - class CasRedirectServlet(RestServlet): PATTERNS = client_path_patterns("/login/(cas|sso)/redirect") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c52280c50c..c1cdb8f9c8 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -33,7 +33,6 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.threepids import check_3pid_allowed @@ -193,13 +192,6 @@ class RegisterRestServlet(RestServlet): self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() - if self.hs.config.worker_app: - self._register_device_client = ( - RegisterDeviceReplicationServlet.make_client(hs) - ) - else: - self.device_handler = hs.get_device_handler() - @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): @@ -642,7 +634,7 @@ class RegisterRestServlet(RestServlet): if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self._register_device( + device_id, access_token = yield self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=False, ) @@ -652,43 +644,6 @@ class RegisterRestServlet(RestServlet): }) defer.returnValue(result) - @defer.inlineCallbacks - def _register_device(self, user_id, device_id, initial_display_name, - is_guest): - """Register a device for a user and generate an access token. - - Args: - user_id (str): full canonical @user:id - device_id (str|None): The device ID to check, or None to generate - a new one. - initial_display_name (str|None): An optional display name for the - device. - is_guest (bool): Whether this is a guest account - Returns: - defer.Deferred[tuple[str, str]]: Tuple of device ID and access token - """ - if self.hs.config.worker_app: - r = yield self._register_device_client( - user_id=user_id, - device_id=device_id, - initial_display_name=initial_display_name, - is_guest=is_guest, - ) - defer.returnValue((r["device_id"], r["access_token"])) - else: - device_id = yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) - if is_guest: - access_token = self.macaroon_gen.generate_access_token( - user_id, ["guest = true"] - ) - else: - access_token = yield self.auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, - ) - defer.returnValue((device_id, access_token)) - @defer.inlineCallbacks def _do_guest_registration(self, params): if not self.hs.config.allow_guest_access: @@ -702,7 +657,7 @@ class RegisterRestServlet(RestServlet): # we have nowhere to store it. device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self._register_device( + device_id, access_token = yield self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=True, ) -- cgit 1.4.1