diff --git a/changelog.d/14055.misc b/changelog.d/14055.misc
new file mode 100644
index 0000000000..02980bc528
--- /dev/null
+++ b/changelog.d/14055.misc
@@ -0,0 +1 @@
+Add missing type hints to `HomeServer`.
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 816e1a6d79..d74d135c0c 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Codes, Requester, UserID, create_requester
@@ -76,6 +77,9 @@ class DeactivateAccountHandler:
True if identity server supports removing threepids, otherwise False.
"""
+ # This can only be called on the main process.
+ assert isinstance(self._device_handler, DeviceHandler)
+
# Check if this user can be deactivated
if not await self._third_party_rules.check_can_deactivate_user(
user_id, by_admin
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index da3ddafeae..b1e55e1b9e 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
class DeviceWorkerHandler:
+ device_list_updater: "DeviceListWorkerUpdater"
+
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
@@ -76,6 +78,8 @@ class DeviceWorkerHandler:
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
+ self.device_list_updater = DeviceListWorkerUpdater(hs)
+
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
@@ -99,6 +103,19 @@ class DeviceWorkerHandler:
log_kv(device_map)
return devices
+ async def get_dehydrated_device(
+ self, user_id: str
+ ) -> Optional[Tuple[str, JsonDict]]:
+ """Retrieve the information for a dehydrated device.
+
+ Args:
+ user_id: the user whose dehydrated device we are looking for
+ Returns:
+ a tuple whose first item is the device ID, and the second item is
+ the dehydrated device information
+ """
+ return await self.store.get_dehydrated_device(user_id)
+
@trace
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
"""Retrieve the given device
@@ -127,7 +144,7 @@ class DeviceWorkerHandler:
@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
- ) -> Collection[str]:
+ ) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
@@ -320,6 +337,8 @@ class DeviceWorkerHandler:
class DeviceHandler(DeviceWorkerHandler):
+ device_list_updater: "DeviceListUpdater"
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler):
await self.delete_devices(user_id, [old_device_id])
return device_id
- async def get_dehydrated_device(
- self, user_id: str
- ) -> Optional[Tuple[str, JsonDict]]:
- """Retrieve the information for a dehydrated device.
-
- Args:
- user_id: the user whose dehydrated device we are looking for
- Returns:
- a tuple whose first item is the device ID, and the second item is
- the dehydrated device information
- """
- return await self.store.get_dehydrated_device(user_id)
-
async def rehydrate_device(
self, user_id: str, access_token: str, device_id: str
) -> dict:
@@ -882,7 +888,36 @@ def _update_device_from_client_ips(
)
-class DeviceListUpdater:
+class DeviceListWorkerUpdater:
+ "Handles incoming device list updates from federation and contacts the main process over replication"
+
+ def __init__(self, hs: "HomeServer"):
+ from synapse.replication.http.devices import (
+ ReplicationUserDevicesResyncRestServlet,
+ )
+
+ self._user_device_resync_client = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
+ )
+
+ async def user_device_resync(
+ self, user_id: str, mark_failed_as_stale: bool = True
+ ) -> Optional[JsonDict]:
+ """Fetches all devices for a user and updates the device cache with them.
+
+ Args:
+ user_id: The user's id whose device_list will be updated.
+ mark_failed_as_stale: Whether to mark the user's device list as stale
+ if the attempt to resync failed.
+ Returns:
+ A dict with device info as under the "devices" in the result of this
+ request:
+ https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+ """
+ return await self._user_device_resync_client(user_id=user_id)
+
+
+class DeviceListUpdater(DeviceListWorkerUpdater):
"Handles incoming device list updates from federation and updates the DB"
def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index bf1221f523..5fe102e2f2 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -27,9 +27,9 @@ from twisted.internet import defer
from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
-from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
JsonDict,
UserID,
@@ -56,27 +56,23 @@ class E2eKeysHandler:
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
- self._edu_updater = SigningKeyEduUpdater(hs, self)
-
federation_registry = hs.get_federation_registry()
- self._is_master = hs.config.worker.worker_app is None
- if not self._is_master:
- self._user_device_resync_client = (
- ReplicationUserDevicesResyncRestServlet.make_client(hs)
- )
- else:
+ is_master = hs.config.worker.worker_app is None
+ if is_master:
+ edu_updater = SigningKeyEduUpdater(hs)
+
# Only register this edu handler on master as it requires writing
# device updates to the db
federation_registry.register_edu_handler(
EduTypes.SIGNING_KEY_UPDATE,
- self._edu_updater.incoming_signing_key_update,
+ edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
- self._edu_updater.incoming_signing_key_update,
+ edu_updater.incoming_signing_key_update,
)
# doesn't really work as part of the generic query API, because the
@@ -319,14 +315,13 @@ class E2eKeysHandler:
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
- if self._is_master:
- resync_results = await self.device_handler.device_list_updater.user_device_resync(
+ resync_results = (
+ await self.device_handler.device_list_updater.user_device_resync(
user_id
)
- else:
- resync_results = await self._user_device_resync_client(
- user_id=user_id
- )
+ )
+ if resync_results is None:
+ raise ValueError("Device resync failed")
# Add the device keys to the results.
user_devices = resync_results["devices"]
@@ -605,6 +600,8 @@ class E2eKeysHandler:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec()
@@ -732,6 +729,8 @@ class E2eKeysHandler:
user_id: the user uploading the keys
keys: the signing keys
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
# if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys
@@ -823,6 +822,9 @@ class E2eKeysHandler:
Raises:
SynapseError: if the signatures dict is not valid.
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
failures = {}
# signatures to be stored. Each item will be a SignatureListItem
@@ -1200,6 +1202,9 @@ class E2eKeysHandler:
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
try:
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
@@ -1396,11 +1401,14 @@ class SignatureListItem:
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
- def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
- self.e2e_keys_handler = e2e_keys_handler
+
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self._device_handler = device_handler
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
@@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater:
user_id: the user whose updates we are processing
"""
- device_handler = self.e2e_keys_handler.device_handler
- device_list_updater = device_handler.device_list_updater
-
async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
@@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater:
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
- new_device_ids = (
- await device_list_updater.process_cross_signing_key_update(
- user_id,
- master_key,
- self_signing_key,
- )
+ new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
+ user_id,
+ master_key,
+ self_signing_key,
)
device_ids = device_ids + new_device_ids
- await device_handler.notify_device_update(user_id, device_ids)
+ await self._device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ca1c7a1866..6307fa9c5d 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
)
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
+from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import (
@@ -841,6 +842,9 @@ class RegistrationHandler:
refresh_token = None
refresh_token_id = None
+ # This can only run on the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
registered_device_id = await self.device_handler.check_device_registered(
user_id,
device_id,
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 73861bbd40..bd9d0bb34b 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -15,6 +15,7 @@ import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.types import Requester
if TYPE_CHECKING:
@@ -29,7 +30,10 @@ class SetPasswordHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._auth_handler = hs.get_auth_handler()
- self._device_handler = hs.get_device_handler()
+ # This can only be instantiated on the main process.
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self._device_handler = device_handler
async def set_password(
self,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 749d7e93b0..e1c0bff1b2 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -37,6 +37,7 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.device import DeviceHandler
from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
@@ -1035,6 +1036,8 @@ class SsoHandler:
) -> None:
"""Revoke any devices and in-flight logins tied to a provider session.
+ Can only be called from the main process.
+
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
@@ -1042,6 +1045,12 @@ class SsoHandler:
expected_user_id: The user we're expecting to logout. If set, it will ignore
sessions belonging to other users and log an error.
"""
+
+ # It is expected that this is the main process.
+ assert isinstance(
+ self._device_handler, DeviceHandler
+ ), "revoking SSO sessions can only be called on the main process"
+
# Invalidate any running user-mapping sessions
to_delete = []
for session_id, session in self._username_mapping_sessions.items():
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 1adc1fd64f..96a661177a 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -86,6 +86,7 @@ from synapse.handlers.auth import (
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
+from synapse.handlers.device import DeviceHandler
from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
@@ -207,6 +208,7 @@ class ModuleApi:
self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler()
self._push_rules_handler = hs.get_push_rules_handler()
+ self._device_handler = hs.get_device_handler()
self.custom_template_dir = hs.config.server.custom_template_directory
try:
@@ -784,6 +786,8 @@ class ModuleApi:
) -> Generator["defer.Deferred[Any]", Any, None]:
"""Invalidate an access token for a user
+ Can only be called from the main process.
+
Added in Synapse v0.25.0.
Args:
@@ -796,6 +800,10 @@ class ModuleApi:
Raises:
synapse.api.errors.AuthError: the access token is invalid
"""
+ assert isinstance(
+ self._device_handler, DeviceHandler
+ ), "invalidate_access_token can only be called on the main process"
+
# see if the access token corresponds to a device
user_info = yield defer.ensureDeferred(
self._auth.get_user_by_access_token(access_token)
@@ -805,7 +813,7 @@ class ModuleApi:
if device_id:
# delete the device, which will also delete its access tokens
yield defer.ensureDeferred(
- self._hs.get_device_handler().delete_devices(user_id, [device_id])
+ self._device_handler.delete_devices(user_id, [device_id])
)
else:
# no associated device. Just delete the access token.
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index c21629def8..7c4941c3d3 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.server import Request
@@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.device_list_updater = hs.get_device_handler().device_list_updater
+ from synapse.handlers.device import DeviceHandler
+
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_list_updater = handler.device_list_updater
+
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
- ) -> Tuple[int, JsonDict]:
+ ) -> Tuple[int, Optional[JsonDict]]:
user_devices = await self.device_list_updater.user_device_resync(user_id)
return 200, user_devices
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index c62ea22116..fb73886df0 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
"""
Register all the admin servlets.
"""
+ # Admin servlets aren't registered on workers.
+ if hs.config.worker.worker_app is not None:
+ return
+
register_servlets_for_client_rest_resource(hs, http_server)
BlockRoomRestServlet(hs).register(http_server)
ListRoomRestServlet(hs).register(http_server)
@@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
- DeviceRestServlet(hs).register(http_server)
- DevicesRestServlet(hs).register(http_server)
- DeleteDevicesRestServlet(hs).register(http_server)
UserMediaStatisticsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
@@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserByExternalId(hs).register(http_server)
UserByThreePid(hs).register(http_server)
- # Some servlets only get registered for the main process.
- if hs.config.worker.worker_app is None:
- SendServerNoticeServlet(hs).register(http_server)
- BackgroundUpdateEnabledRestServlet(hs).register(http_server)
- BackgroundUpdateRestServlet(hs).register(http_server)
- BackgroundUpdateStartJobRestServlet(hs).register(http_server)
+ DeviceRestServlet(hs).register(http_server)
+ DevicesRestServlet(hs).register(http_server)
+ DeleteDevicesRestServlet(hs).register(http_server)
+ SendServerNoticeServlet(hs).register(http_server)
+ BackgroundUpdateEnabledRestServlet(hs).register(http_server)
+ BackgroundUpdateRestServlet(hs).register(http_server)
+ BackgroundUpdateStartJobRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
@@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource(
"""Register only the servlets which need to be exposed on /_matrix/client/xxx"""
WhoisRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server)
- DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
- ResetPasswordRestServlet(hs).register(http_server)
+ # The following resources can only be run on the main process.
+ if hs.config.worker.worker_app is None:
+ DeactivateAccountRestServlet(hs).register(http_server)
+ ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index d934880102..3b2f2d9abb 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -16,6 +16,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
@@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
@@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8f3cbd4ea2..69b803f9f8 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr
from synapse.api import errors
from synapse.api.errors import NotFoundError
+from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.auth_handler = hs.get_auth_handler()
class PostBody(RequestBodyModel):
@@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.auth_handler = hs.get_auth_handler()
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
@@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
class PostBody(RequestBodyModel):
device_id: StrictStr
diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py
index 23dfa4518f..6d34625ad5 100644
--- a/synapse/rest/client/logout.py
+++ b/synapse/rest/client/logout.py
@@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
+from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
@@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
- self._device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
@@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
- self._device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
diff --git a/synapse/server.py b/synapse/server.py
index f0a60d0056..5baae2325e 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta):
)
@cache_in_self
- def get_device_handler(self):
+ def get_device_handler(self) -> DeviceWorkerHandler:
if self.config.worker.worker_app:
return DeviceWorkerHandler(self)
else:
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index b8b465d35b..ce7525e29c 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -19,7 +19,7 @@ from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError, SynapseError
-from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
+from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer
from synapse.util import Clock
@@ -32,7 +32,9 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.store = hs.get_datastores().main
return hs
@@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display name")
def test_device_is_preserved_if_exists(self) -> None:
@@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res2, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display name")
def test_device_id_is_made_up_if_unspecified(self) -> None:
@@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display")
def test_get_devices_by_user(self) -> None:
@@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
@@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
)
)
- retrieved_device_id, device_data = self.get_success(
- self.handler.get_dehydrated_device(user_id=user_id)
- )
+ result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+ assert result is not None
+ retrieved_device_id, device_data = result
self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index d52aee8f92..03f2112b07 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
+from synapse.handlers.device import DeviceHandler
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
|