summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/14055.misc1
-rw-r--r--synapse/handlers/deactivate_account.py4
-rw-r--r--synapse/handlers/device.py65
-rw-r--r--synapse/handlers/e2e_keys.py61
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/handlers/set_password.py6
-rw-r--r--synapse/handlers/sso.py9
-rw-r--r--synapse/module_api/__init__.py10
-rw-r--r--synapse/replication/http/devices.py11
-rw-r--r--synapse/rest/admin/__init__.py26
-rw-r--r--synapse/rest/admin/devices.py13
-rw-r--r--synapse/rest/client/devices.py17
-rw-r--r--synapse/rest/client/logout.py9
-rw-r--r--synapse/server.py2
-rw-r--r--tests/handlers/test_device.py19
-rw-r--r--tests/rest/admin/test_device.py5
16 files changed, 185 insertions, 77 deletions
diff --git a/changelog.d/14055.misc b/changelog.d/14055.misc
new file mode 100644
index 0000000000..02980bc528
--- /dev/null
+++ b/changelog.d/14055.misc
@@ -0,0 +1 @@
+Add missing type hints to `HomeServer`.
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 816e1a6d79..d74d135c0c 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -16,6 +16,7 @@ import logging
 from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import Codes, Requester, UserID, create_requester
 
@@ -76,6 +77,9 @@ class DeactivateAccountHandler:
             True if identity server supports removing threepids, otherwise False.
         """
 
+        # This can only be called on the main process.
+        assert isinstance(self._device_handler, DeviceHandler)
+
         # Check if this user can be deactivated
         if not await self._third_party_rules.check_can_deactivate_user(
             user_id, by_admin
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index da3ddafeae..b1e55e1b9e 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
 
 
 class DeviceWorkerHandler:
+    device_list_updater: "DeviceListWorkerUpdater"
+
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.hs = hs
@@ -76,6 +78,8 @@ class DeviceWorkerHandler:
         self.server_name = hs.hostname
         self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
+        self.device_list_updater = DeviceListWorkerUpdater(hs)
+
     @trace
     async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
         """
@@ -99,6 +103,19 @@ class DeviceWorkerHandler:
         log_kv(device_map)
         return devices
 
+    async def get_dehydrated_device(
+        self, user_id: str
+    ) -> Optional[Tuple[str, JsonDict]]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        return await self.store.get_dehydrated_device(user_id)
+
     @trace
     async def get_device(self, user_id: str, device_id: str) -> JsonDict:
         """Retrieve the given device
@@ -127,7 +144,7 @@ class DeviceWorkerHandler:
     @cancellable
     async def get_device_changes_in_shared_rooms(
         self, user_id: str, room_ids: Collection[str], from_token: StreamToken
-    ) -> Collection[str]:
+    ) -> Set[str]:
         """Get the set of users whose devices have changed who share a room with
         the given user.
         """
@@ -320,6 +337,8 @@ class DeviceWorkerHandler:
 
 
 class DeviceHandler(DeviceWorkerHandler):
+    device_list_updater: "DeviceListUpdater"
+
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
@@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler):
             await self.delete_devices(user_id, [old_device_id])
         return device_id
 
-    async def get_dehydrated_device(
-        self, user_id: str
-    ) -> Optional[Tuple[str, JsonDict]]:
-        """Retrieve the information for a dehydrated device.
-
-        Args:
-            user_id: the user whose dehydrated device we are looking for
-        Returns:
-            a tuple whose first item is the device ID, and the second item is
-            the dehydrated device information
-        """
-        return await self.store.get_dehydrated_device(user_id)
-
     async def rehydrate_device(
         self, user_id: str, access_token: str, device_id: str
     ) -> dict:
@@ -882,7 +888,36 @@ def _update_device_from_client_ips(
     )
 
 
-class DeviceListUpdater:
+class DeviceListWorkerUpdater:
+    "Handles incoming device list updates from federation and contacts the main process over replication"
+
+    def __init__(self, hs: "HomeServer"):
+        from synapse.replication.http.devices import (
+            ReplicationUserDevicesResyncRestServlet,
+        )
+
+        self._user_device_resync_client = (
+            ReplicationUserDevicesResyncRestServlet.make_client(hs)
+        )
+
+    async def user_device_resync(
+        self, user_id: str, mark_failed_as_stale: bool = True
+    ) -> Optional[JsonDict]:
+        """Fetches all devices for a user and updates the device cache with them.
+
+        Args:
+            user_id: The user's id whose device_list will be updated.
+            mark_failed_as_stale: Whether to mark the user's device list as stale
+                if the attempt to resync failed.
+        Returns:
+            A dict with device info as under the "devices" in the result of this
+            request:
+            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+        """
+        return await self._user_device_resync_client(user_id=user_id)
+
+
+class DeviceListUpdater(DeviceListWorkerUpdater):
     "Handles incoming device list updates from federation and updates the DB"
 
     def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index bf1221f523..5fe102e2f2 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -27,9 +27,9 @@ from twisted.internet import defer
 
 from synapse.api.constants import EduTypes
 from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
-from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import (
     JsonDict,
     UserID,
@@ -56,27 +56,23 @@ class E2eKeysHandler:
         self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
 
-        self._edu_updater = SigningKeyEduUpdater(hs, self)
-
         federation_registry = hs.get_federation_registry()
 
-        self._is_master = hs.config.worker.worker_app is None
-        if not self._is_master:
-            self._user_device_resync_client = (
-                ReplicationUserDevicesResyncRestServlet.make_client(hs)
-            )
-        else:
+        is_master = hs.config.worker.worker_app is None
+        if is_master:
+            edu_updater = SigningKeyEduUpdater(hs)
+
             # Only register this edu handler on master as it requires writing
             # device updates to the db
             federation_registry.register_edu_handler(
                 EduTypes.SIGNING_KEY_UPDATE,
-                self._edu_updater.incoming_signing_key_update,
+                edu_updater.incoming_signing_key_update,
             )
             # also handle the unstable version
             # FIXME: remove this when enough servers have upgraded
             federation_registry.register_edu_handler(
                 EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
-                self._edu_updater.incoming_signing_key_update,
+                edu_updater.incoming_signing_key_update,
             )
 
         # doesn't really work as part of the generic query API, because the
@@ -319,14 +315,13 @@ class E2eKeysHandler:
             # probably be tracking their device lists. However, we haven't
             # done an initial sync on the device list so we do it now.
             try:
-                if self._is_master:
-                    resync_results = await self.device_handler.device_list_updater.user_device_resync(
+                resync_results = (
+                    await self.device_handler.device_list_updater.user_device_resync(
                         user_id
                     )
-                else:
-                    resync_results = await self._user_device_resync_client(
-                        user_id=user_id
-                    )
+                )
+                if resync_results is None:
+                    raise ValueError("Device resync failed")
 
                 # Add the device keys to the results.
                 user_devices = resync_results["devices"]
@@ -605,6 +600,8 @@ class E2eKeysHandler:
     async def upload_keys_for_user(
         self, user_id: str, device_id: str, keys: JsonDict
     ) -> JsonDict:
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
 
         time_now = self.clock.time_msec()
 
@@ -732,6 +729,8 @@ class E2eKeysHandler:
             user_id: the user uploading the keys
             keys: the signing keys
         """
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
 
         # if a master key is uploaded, then check it.  Otherwise, load the
         # stored master key, to check signatures on other keys
@@ -823,6 +822,9 @@ class E2eKeysHandler:
         Raises:
             SynapseError: if the signatures dict is not valid.
         """
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
+
         failures = {}
 
         # signatures to be stored.  Each item will be a SignatureListItem
@@ -1200,6 +1202,9 @@ class E2eKeysHandler:
             A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
             If the key cannot be retrieved, all values in the tuple will instead be None.
         """
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
+
         try:
             remote_result = await self.federation.query_user_devices(
                 user.domain, user.to_string()
@@ -1396,11 +1401,14 @@ class SignatureListItem:
 class SigningKeyEduUpdater:
     """Handles incoming signing key updates from federation and updates the DB"""
 
-    def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
-        self.e2e_keys_handler = e2e_keys_handler
+
+        device_handler = hs.get_device_handler()
+        assert isinstance(device_handler, DeviceHandler)
+        self._device_handler = device_handler
 
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
@@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater:
             user_id: the user whose updates we are processing
         """
 
-        device_handler = self.e2e_keys_handler.device_handler
-        device_list_updater = device_handler.device_list_updater
-
         async with self._remote_edu_linearizer.queue(user_id):
             pending_updates = self._pending_updates.pop(user_id, [])
             if not pending_updates:
@@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater:
             logger.info("pending updates: %r", pending_updates)
 
             for master_key, self_signing_key in pending_updates:
-                new_device_ids = (
-                    await device_list_updater.process_cross_signing_key_update(
-                        user_id,
-                        master_key,
-                        self_signing_key,
-                    )
+                new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
+                    user_id,
+                    master_key,
+                    self_signing_key,
                 )
                 device_ids = device_ids + new_device_ids
 
-            await device_handler.notify_device_update(user_id, device_ids)
+            await self._device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ca1c7a1866..6307fa9c5d 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
 )
 from synapse.appservice import ApplicationService
 from synapse.config.server import is_threepid_reserved
+from synapse.handlers.device import DeviceHandler
 from synapse.http.servlet import assert_params_in_dict
 from synapse.replication.http.login import RegisterDeviceReplicationServlet
 from synapse.replication.http.register import (
@@ -841,6 +842,9 @@ class RegistrationHandler:
         refresh_token = None
         refresh_token_id = None
 
+        # This can only run on the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
+
         registered_device_id = await self.device_handler.check_device_registered(
             user_id,
             device_id,
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 73861bbd40..bd9d0bb34b 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -15,6 +15,7 @@ import logging
 from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import Codes, StoreError, SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.types import Requester
 
 if TYPE_CHECKING:
@@ -29,7 +30,10 @@ class SetPasswordHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self._auth_handler = hs.get_auth_handler()
-        self._device_handler = hs.get_device_handler()
+        # This can only be instantiated on the main process.
+        device_handler = hs.get_device_handler()
+        assert isinstance(device_handler, DeviceHandler)
+        self._device_handler = device_handler
 
     async def set_password(
         self,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 749d7e93b0..e1c0bff1b2 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -37,6 +37,7 @@ from twisted.web.server import Request
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
 from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.device import DeviceHandler
 from synapse.handlers.register import init_counters_for_auth_provider
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
@@ -1035,6 +1036,8 @@ class SsoHandler:
     ) -> None:
         """Revoke any devices and in-flight logins tied to a provider session.
 
+        Can only be called from the main process.
+
         Args:
             auth_provider_id: A unique identifier for this SSO provider, e.g.
                 "oidc" or "saml".
@@ -1042,6 +1045,12 @@ class SsoHandler:
             expected_user_id: The user we're expecting to logout. If set, it will ignore
                 sessions belonging to other users and log an error.
         """
+
+        # It is expected that this is the main process.
+        assert isinstance(
+            self._device_handler, DeviceHandler
+        ), "revoking SSO sessions can only be called on the main process"
+
         # Invalidate any running user-mapping sessions
         to_delete = []
         for session_id, session in self._username_mapping_sessions.items():
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 1adc1fd64f..96a661177a 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -86,6 +86,7 @@ from synapse.handlers.auth import (
     ON_LOGGED_OUT_CALLBACK,
     AuthHandler,
 )
+from synapse.handlers.device import DeviceHandler
 from synapse.handlers.push_rules import RuleSpec, check_actions
 from synapse.http.client import SimpleHttpClient
 from synapse.http.server import (
@@ -207,6 +208,7 @@ class ModuleApi:
         self._registration_handler = hs.get_registration_handler()
         self._send_email_handler = hs.get_send_email_handler()
         self._push_rules_handler = hs.get_push_rules_handler()
+        self._device_handler = hs.get_device_handler()
         self.custom_template_dir = hs.config.server.custom_template_directory
 
         try:
@@ -784,6 +786,8 @@ class ModuleApi:
     ) -> Generator["defer.Deferred[Any]", Any, None]:
         """Invalidate an access token for a user
 
+        Can only be called from the main process.
+
         Added in Synapse v0.25.0.
 
         Args:
@@ -796,6 +800,10 @@ class ModuleApi:
         Raises:
             synapse.api.errors.AuthError: the access token is invalid
         """
+        assert isinstance(
+            self._device_handler, DeviceHandler
+        ), "invalidate_access_token can only be called on the main process"
+
         # see if the access token corresponds to a device
         user_info = yield defer.ensureDeferred(
             self._auth.get_user_by_access_token(access_token)
@@ -805,7 +813,7 @@ class ModuleApi:
         if device_id:
             # delete the device, which will also delete its access tokens
             yield defer.ensureDeferred(
-                self._hs.get_device_handler().delete_devices(user_id, [device_id])
+                self._device_handler.delete_devices(user_id, [device_id])
             )
         else:
             # no associated device. Just delete the access token.
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index c21629def8..7c4941c3d3 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from twisted.web.server import Request
 
@@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.device_list_updater = hs.get_device_handler().device_list_updater
+        from synapse.handlers.device import DeviceHandler
+
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_list_updater = handler.device_list_updater
+
         self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
 
@@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
 
     async def _handle_request(  # type: ignore[override]
         self, request: Request, user_id: str
-    ) -> Tuple[int, JsonDict]:
+    ) -> Tuple[int, Optional[JsonDict]]:
         user_devices = await self.device_list_updater.user_device_resync(user_id)
 
         return 200, user_devices
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index c62ea22116..fb73886df0 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     """
     Register all the admin servlets.
     """
+    # Admin servlets aren't registered on workers.
+    if hs.config.worker.worker_app is not None:
+        return
+
     register_servlets_for_client_rest_resource(hs, http_server)
     BlockRoomRestServlet(hs).register(http_server)
     ListRoomRestServlet(hs).register(http_server)
@@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     UserTokenRestServlet(hs).register(http_server)
     UserRestServletV2(hs).register(http_server)
     UsersRestServletV2(hs).register(http_server)
-    DeviceRestServlet(hs).register(http_server)
-    DevicesRestServlet(hs).register(http_server)
-    DeleteDevicesRestServlet(hs).register(http_server)
     UserMediaStatisticsRestServlet(hs).register(http_server)
     EventReportDetailRestServlet(hs).register(http_server)
     EventReportsRestServlet(hs).register(http_server)
@@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     UserByExternalId(hs).register(http_server)
     UserByThreePid(hs).register(http_server)
 
-    # Some servlets only get registered for the main process.
-    if hs.config.worker.worker_app is None:
-        SendServerNoticeServlet(hs).register(http_server)
-        BackgroundUpdateEnabledRestServlet(hs).register(http_server)
-        BackgroundUpdateRestServlet(hs).register(http_server)
-        BackgroundUpdateStartJobRestServlet(hs).register(http_server)
+    DeviceRestServlet(hs).register(http_server)
+    DevicesRestServlet(hs).register(http_server)
+    DeleteDevicesRestServlet(hs).register(http_server)
+    SendServerNoticeServlet(hs).register(http_server)
+    BackgroundUpdateEnabledRestServlet(hs).register(http_server)
+    BackgroundUpdateRestServlet(hs).register(http_server)
+    BackgroundUpdateStartJobRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(
@@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource(
     """Register only the servlets which need to be exposed on /_matrix/client/xxx"""
     WhoisRestServlet(hs).register(http_server)
     PurgeHistoryStatusRestServlet(hs).register(http_server)
-    DeactivateAccountRestServlet(hs).register(http_server)
     PurgeHistoryRestServlet(hs).register(http_server)
-    ResetPasswordRestServlet(hs).register(http_server)
+    # The following resources can only be run on the main process.
+    if hs.config.worker.worker_app is None:
+        DeactivateAccountRestServlet(hs).register(http_server)
+        ResetPasswordRestServlet(hs).register(http_server)
     SearchUsersRestServlet(hs).register(http_server)
     UserRegisterServlet(hs).register(http_server)
     AccountValidityRenewServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index d934880102..3b2f2d9abb 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -16,6 +16,7 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_handler = handler
         self.store = hs.get_datastores().main
         self.is_mine = hs.is_mine
 
@@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_handler = handler
         self.store = hs.get_datastores().main
         self.is_mine = hs.is_mine
 
@@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_handler = handler
         self.store = hs.get_datastores().main
         self.is_mine = hs.is_mine
 
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8f3cbd4ea2..69b803f9f8 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr
 
 from synapse.api import errors
 from synapse.api.errors import NotFoundError
+from synapse.handlers.device import DeviceHandler
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
@@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_handler = handler
         self.auth_handler = hs.get_auth_handler()
 
     class PostBody(RequestBodyModel):
@@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_handler = handler
         self.auth_handler = hs.get_auth_handler()
         self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
@@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_handler = handler
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
@@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_handler = handler
 
     class PostBody(RequestBodyModel):
         device_id: StrictStr
diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py
index 23dfa4518f..6d34625ad5 100644
--- a/synapse/rest/client/logout.py
+++ b/synapse/rest/client/logout.py
@@ -15,6 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
+from synapse.handlers.device import DeviceHandler
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet
 from synapse.http.site import SynapseRequest
@@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
-        self._device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self._device_handler = handler
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_expired=True)
@@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
-        self._device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self._device_handler = handler
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_expired=True)
diff --git a/synapse/server.py b/synapse/server.py
index f0a60d0056..5baae2325e 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         )
 
     @cache_in_self
-    def get_device_handler(self):
+    def get_device_handler(self) -> DeviceWorkerHandler:
         if self.config.worker.worker_app:
             return DeviceWorkerHandler(self)
         else:
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index b8b465d35b..ce7525e29c 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -19,7 +19,7 @@ from typing import Optional
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import NotFoundError, SynapseError
-from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
+from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -32,7 +32,9 @@ user2 = "@theresa:bbb"
 class DeviceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver("server", federation_http_client=None)
-        self.handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.handler = handler
         self.store = hs.get_datastores().main
         return hs
 
@@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(res, "fco")
 
         dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+        assert dev is not None
         self.assertEqual(dev["display_name"], "display name")
 
     def test_device_is_preserved_if_exists(self) -> None:
@@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(res2, "fco")
 
         dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+        assert dev is not None
         self.assertEqual(dev["display_name"], "display name")
 
     def test_device_id_is_made_up_if_unspecified(self) -> None:
@@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         )
 
         dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
+        assert dev is not None
         self.assertEqual(dev["display_name"], "display")
 
     def test_get_devices_by_user(self) -> None:
@@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 class DehydrationTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver("server", federation_http_client=None)
-        self.handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.handler = handler
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
@@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        retrieved_device_id, device_data = self.get_success(
-            self.handler.get_dehydrated_device(user_id=user_id)
-        )
+        result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+        assert result is not None
+        retrieved_device_id, device_data = result
 
         self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
         self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index d52aee8f92..03f2112b07 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.api.errors import Codes
+from synapse.handlers.device import DeviceHandler
 from synapse.rest.client import login
 from synapse.server import HomeServer
 from synapse.util import Clock
@@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.handler = handler
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")