diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 7f812b8209..d78da50787 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -33,6 +33,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.threepids import check_3pid_allowed
@@ -190,9 +191,15 @@ class RegisterRestServlet(RestServlet):
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
self.room_member_handler = hs.get_room_member_handler()
- self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
+ if self.hs.config.worker_app:
+ self._register_device_client = (
+ RegisterDeviceReplicationServlet.make_client(hs)
+ )
+ else:
+ self.device_handler = hs.get_device_handler()
+
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
@@ -633,12 +640,10 @@ class RegisterRestServlet(RestServlet):
"home_server": self.hs.hostname,
}
if not params.get("inhibit_login", False):
- device_id = yield self._register_device(user_id, params)
-
- access_token = (
- yield self.auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id,
- )
+ device_id = params.get("device_id")
+ initial_display_name = params.get("initial_device_display_name")
+ device_id, access_token = yield self._register_device(
+ user_id, device_id, initial_display_name, is_guest=False,
)
result.update({
@@ -647,25 +652,42 @@ class RegisterRestServlet(RestServlet):
})
defer.returnValue(result)
- def _register_device(self, user_id, params):
- """Register a device for a user.
-
- This is called after the user's credentials have been validated, but
- before the access token has been issued.
+ @defer.inlineCallbacks
+ def _register_device(self, user_id, device_id, initial_display_name,
+ is_guest):
+ """Register a device for a user and generate an access token.
Args:
- (str) user_id: full canonical @user:id
- (object) params: registration parameters, from which we pull
- device_id and initial_device_name
+ user_id (str): full canonical @user:id
+ device_id (str|None): The device ID to check, or None to generate
+ a new one.
+ initial_display_name (str|None): An optional display name for the
+ device.
+ is_guest (bool): Whether this is a guest account
Returns:
- defer.Deferred: (str) device_id
+ defer.Deferred[(str, str)]: Tuple of device ID and access token
"""
- # register the user's device
- device_id = params.get("device_id")
- initial_display_name = params.get("initial_device_display_name")
- return self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
- )
+ if self.hs.config.worker_app:
+ r = yield self._register_device_client(
+ user_id=user_id,
+ device_id=device_id,
+ initial_display_name=initial_display_name,
+ is_guest=is_guest,
+ )
+ defer.returnValue((r["device_id"], r["access_token"]))
+ else:
+ device_id = yield self.device_handler.check_device_registered(
+ user_id, device_id, initial_display_name
+ )
+ if is_guest:
+ access_token = self.macaroon_gen.generate_access_token(
+ user_id, ["guest = true"]
+ )
+ else:
+ access_token = yield self.auth_handler.get_access_token_for_user_id(
+ user_id, device_id=device_id,
+ )
+ defer.returnValue((device_id, access_token))
@defer.inlineCallbacks
def _do_guest_registration(self, params):
@@ -680,13 +702,10 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
- yield self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
+ device_id, access_token = yield self._register_device(
+ user_id, device_id, initial_display_name, is_guest=True,
)
- access_token = self.macaroon_gen.generate_access_token(
- user_id, ["guest = true"]
- )
defer.returnValue((200, {
"user_id": user_id,
"device_id": device_id,
|