diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index b66f8756b8..cd001e87c7 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -16,7 +16,7 @@
"""Contains functions for registering clients."""
import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from prometheus_client import Counter
@@ -82,6 +82,7 @@ class RegistrationHandler(BaseHandler):
)
else:
self.device_handler = hs.get_device_handler()
+ self._register_device_client = self.register_device_inner
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.session_lifetime
@@ -678,17 +679,35 @@ class RegistrationHandler(BaseHandler):
Returns:
Tuple of device ID and access token
"""
+ res = await self._register_device_client(
+ user_id=user_id,
+ device_id=device_id,
+ initial_display_name=initial_display_name,
+ is_guest=is_guest,
+ is_appservice_ghost=is_appservice_ghost,
+ )
- if self.hs.config.worker_app:
- r = await self._register_device_client(
- user_id=user_id,
- device_id=device_id,
- initial_display_name=initial_display_name,
- is_guest=is_guest,
- is_appservice_ghost=is_appservice_ghost,
- )
- return r["device_id"], r["access_token"]
+ login_counter.labels(
+ guest=is_guest,
+ auth_provider=(auth_provider_id or ""),
+ ).inc()
+
+ return res["device_id"], res["access_token"]
+
+ async def register_device_inner(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ is_guest: bool = False,
+ is_appservice_ghost: bool = False,
+ ) -> Dict[str, str]:
+ """Helper for register_device
+ Does the bits that need doing on the main process. Not for use outside this
+ class and RegisterDeviceReplicationServlet.
+ """
+ assert not self.hs.config.worker_app
valid_until_ms = None
if self.session_lifetime is not None:
if is_guest:
@@ -713,12 +732,7 @@ class RegistrationHandler(BaseHandler):
is_appservice_ghost=is_appservice_ghost,
)
- login_counter.labels(
- guest=is_guest,
- auth_provider=(auth_provider_id or ""),
- ).inc()
-
- return (registered_device_id, access_token)
+ return {"device_id": registered_device_id, "access_token": access_token}
async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 36071feb36..4ec1bfa6ea 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -61,7 +61,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
- device_id, access_token = await self.registration_handler.register_device(
+ res = await self.registration_handler.register_device_inner(
user_id,
device_id,
initial_display_name,
@@ -69,7 +69,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_appservice_ghost=is_appservice_ghost,
)
- return 200, {"device_id": device_id, "access_token": access_token}
+ return 200, res
def register_servlets(hs, http_server):
|