summary refs log tree commit diff
path: root/synapse/handlers/register.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/register.py')
-rw-r--r--synapse/handlers/register.py46
1 files changed, 30 insertions, 16 deletions
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index b66f8756b8..cd001e87c7 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -16,7 +16,7 @@
 """Contains functions for registering clients."""
 
 import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 from prometheus_client import Counter
 
@@ -82,6 +82,7 @@ class RegistrationHandler(BaseHandler):
             )
         else:
             self.device_handler = hs.get_device_handler()
+            self._register_device_client = self.register_device_inner
             self.pusher_pool = hs.get_pusherpool()
 
         self.session_lifetime = hs.config.session_lifetime
@@ -678,17 +679,35 @@ class RegistrationHandler(BaseHandler):
         Returns:
             Tuple of device ID and access token
         """
+        res = await self._register_device_client(
+            user_id=user_id,
+            device_id=device_id,
+            initial_display_name=initial_display_name,
+            is_guest=is_guest,
+            is_appservice_ghost=is_appservice_ghost,
+        )
 
-        if self.hs.config.worker_app:
-            r = await self._register_device_client(
-                user_id=user_id,
-                device_id=device_id,
-                initial_display_name=initial_display_name,
-                is_guest=is_guest,
-                is_appservice_ghost=is_appservice_ghost,
-            )
-            return r["device_id"], r["access_token"]
+        login_counter.labels(
+            guest=is_guest,
+            auth_provider=(auth_provider_id or ""),
+        ).inc()
+
+        return res["device_id"], res["access_token"]
+
+    async def register_device_inner(
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_display_name: Optional[str],
+        is_guest: bool = False,
+        is_appservice_ghost: bool = False,
+    ) -> Dict[str, str]:
+        """Helper for register_device
 
+        Does the bits that need doing on the main process. Not for use outside this
+        class and RegisterDeviceReplicationServlet.
+        """
+        assert not self.hs.config.worker_app
         valid_until_ms = None
         if self.session_lifetime is not None:
             if is_guest:
@@ -713,12 +732,7 @@ class RegistrationHandler(BaseHandler):
                 is_appservice_ghost=is_appservice_ghost,
             )
 
-        login_counter.labels(
-            guest=is_guest,
-            auth_provider=(auth_provider_id or ""),
-        ).inc()
-
-        return (registered_device_id, access_token)
+        return {"device_id": registered_device_id, "access_token": access_token}
 
     async def post_registration_actions(
         self, user_id: str, auth_result: dict, access_token: Optional[str]