summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-11-22 14:08:04 -0500
committerGitHub <noreply@github.com>2022-11-22 14:08:04 -0500
commit6d47b7e32589e816eb766446cc1ff19ea73fc7c1 (patch)
tree3990c590d8fe0d442233790779b7fa62816ca2b2 /synapse/handlers
parentApply correct editorconfig to .pyi files (#14526) (diff)
downloadsynapse-6d47b7e32589e816eb766446cc1ff19ea73fc7c1.tar.xz
Add a type hint for `get_device_handler()` and fix incorrect types. (#14055)
This was the last untyped handler from the HomeServer object. Since
it was being treated as Any (and thus unchecked) it was being used
incorrectly in a few places.
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/deactivate_account.py4
-rw-r--r--synapse/handlers/device.py65
-rw-r--r--synapse/handlers/e2e_keys.py61
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/handlers/set_password.py6
-rw-r--r--synapse/handlers/sso.py9
6 files changed, 104 insertions, 45 deletions
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 816e1a6d79..d74d135c0c 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -16,6 +16,7 @@ import logging
 from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import Codes, Requester, UserID, create_requester
 
@@ -76,6 +77,9 @@ class DeactivateAccountHandler:
             True if identity server supports removing threepids, otherwise False.
         """
 
+        # This can only be called on the main process.
+        assert isinstance(self._device_handler, DeviceHandler)
+
         # Check if this user can be deactivated
         if not await self._third_party_rules.check_can_deactivate_user(
             user_id, by_admin
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index da3ddafeae..b1e55e1b9e 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
 
 
 class DeviceWorkerHandler:
+    device_list_updater: "DeviceListWorkerUpdater"
+
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.hs = hs
@@ -76,6 +78,8 @@ class DeviceWorkerHandler:
         self.server_name = hs.hostname
         self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
+        self.device_list_updater = DeviceListWorkerUpdater(hs)
+
     @trace
     async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
         """
@@ -99,6 +103,19 @@ class DeviceWorkerHandler:
         log_kv(device_map)
         return devices
 
+    async def get_dehydrated_device(
+        self, user_id: str
+    ) -> Optional[Tuple[str, JsonDict]]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        return await self.store.get_dehydrated_device(user_id)
+
     @trace
     async def get_device(self, user_id: str, device_id: str) -> JsonDict:
         """Retrieve the given device
@@ -127,7 +144,7 @@ class DeviceWorkerHandler:
     @cancellable
     async def get_device_changes_in_shared_rooms(
         self, user_id: str, room_ids: Collection[str], from_token: StreamToken
-    ) -> Collection[str]:
+    ) -> Set[str]:
         """Get the set of users whose devices have changed who share a room with
         the given user.
         """
@@ -320,6 +337,8 @@ class DeviceWorkerHandler:
 
 
 class DeviceHandler(DeviceWorkerHandler):
+    device_list_updater: "DeviceListUpdater"
+
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
@@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler):
             await self.delete_devices(user_id, [old_device_id])
         return device_id
 
-    async def get_dehydrated_device(
-        self, user_id: str
-    ) -> Optional[Tuple[str, JsonDict]]:
-        """Retrieve the information for a dehydrated device.
-
-        Args:
-            user_id: the user whose dehydrated device we are looking for
-        Returns:
-            a tuple whose first item is the device ID, and the second item is
-            the dehydrated device information
-        """
-        return await self.store.get_dehydrated_device(user_id)
-
     async def rehydrate_device(
         self, user_id: str, access_token: str, device_id: str
     ) -> dict:
@@ -882,7 +888,36 @@ def _update_device_from_client_ips(
     )
 
 
-class DeviceListUpdater:
+class DeviceListWorkerUpdater:
+    "Handles incoming device list updates from federation and contacts the main process over replication"
+
+    def __init__(self, hs: "HomeServer"):
+        from synapse.replication.http.devices import (
+            ReplicationUserDevicesResyncRestServlet,
+        )
+
+        self._user_device_resync_client = (
+            ReplicationUserDevicesResyncRestServlet.make_client(hs)
+        )
+
+    async def user_device_resync(
+        self, user_id: str, mark_failed_as_stale: bool = True
+    ) -> Optional[JsonDict]:
+        """Fetches all devices for a user and updates the device cache with them.
+
+        Args:
+            user_id: The user's id whose device_list will be updated.
+            mark_failed_as_stale: Whether to mark the user's device list as stale
+                if the attempt to resync failed.
+        Returns:
+            A dict with device info as under the "devices" in the result of this
+            request:
+            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+        """
+        return await self._user_device_resync_client(user_id=user_id)
+
+
+class DeviceListUpdater(DeviceListWorkerUpdater):
     "Handles incoming device list updates from federation and updates the DB"
 
     def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index bf1221f523..5fe102e2f2 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -27,9 +27,9 @@ from twisted.internet import defer
 
 from synapse.api.constants import EduTypes
 from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
-from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import (
     JsonDict,
     UserID,
@@ -56,27 +56,23 @@ class E2eKeysHandler:
         self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
 
-        self._edu_updater = SigningKeyEduUpdater(hs, self)
-
         federation_registry = hs.get_federation_registry()
 
-        self._is_master = hs.config.worker.worker_app is None
-        if not self._is_master:
-            self._user_device_resync_client = (
-                ReplicationUserDevicesResyncRestServlet.make_client(hs)
-            )
-        else:
+        is_master = hs.config.worker.worker_app is None
+        if is_master:
+            edu_updater = SigningKeyEduUpdater(hs)
+
             # Only register this edu handler on master as it requires writing
             # device updates to the db
             federation_registry.register_edu_handler(
                 EduTypes.SIGNING_KEY_UPDATE,
-                self._edu_updater.incoming_signing_key_update,
+                edu_updater.incoming_signing_key_update,
             )
             # also handle the unstable version
             # FIXME: remove this when enough servers have upgraded
             federation_registry.register_edu_handler(
                 EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
-                self._edu_updater.incoming_signing_key_update,
+                edu_updater.incoming_signing_key_update,
             )
 
         # doesn't really work as part of the generic query API, because the
@@ -319,14 +315,13 @@ class E2eKeysHandler:
             # probably be tracking their device lists. However, we haven't
             # done an initial sync on the device list so we do it now.
             try:
-                if self._is_master:
-                    resync_results = await self.device_handler.device_list_updater.user_device_resync(
+                resync_results = (
+                    await self.device_handler.device_list_updater.user_device_resync(
                         user_id
                     )
-                else:
-                    resync_results = await self._user_device_resync_client(
-                        user_id=user_id
-                    )
+                )
+                if resync_results is None:
+                    raise ValueError("Device resync failed")
 
                 # Add the device keys to the results.
                 user_devices = resync_results["devices"]
@@ -605,6 +600,8 @@ class E2eKeysHandler:
     async def upload_keys_for_user(
         self, user_id: str, device_id: str, keys: JsonDict
     ) -> JsonDict:
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
 
         time_now = self.clock.time_msec()
 
@@ -732,6 +729,8 @@ class E2eKeysHandler:
             user_id: the user uploading the keys
             keys: the signing keys
         """
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
 
         # if a master key is uploaded, then check it.  Otherwise, load the
         # stored master key, to check signatures on other keys
@@ -823,6 +822,9 @@ class E2eKeysHandler:
         Raises:
             SynapseError: if the signatures dict is not valid.
         """
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
+
         failures = {}
 
         # signatures to be stored.  Each item will be a SignatureListItem
@@ -1200,6 +1202,9 @@ class E2eKeysHandler:
             A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
             If the key cannot be retrieved, all values in the tuple will instead be None.
         """
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
+
         try:
             remote_result = await self.federation.query_user_devices(
                 user.domain, user.to_string()
@@ -1396,11 +1401,14 @@ class SignatureListItem:
 class SigningKeyEduUpdater:
     """Handles incoming signing key updates from federation and updates the DB"""
 
-    def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
-        self.e2e_keys_handler = e2e_keys_handler
+
+        device_handler = hs.get_device_handler()
+        assert isinstance(device_handler, DeviceHandler)
+        self._device_handler = device_handler
 
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
@@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater:
             user_id: the user whose updates we are processing
         """
 
-        device_handler = self.e2e_keys_handler.device_handler
-        device_list_updater = device_handler.device_list_updater
-
         async with self._remote_edu_linearizer.queue(user_id):
             pending_updates = self._pending_updates.pop(user_id, [])
             if not pending_updates:
@@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater:
             logger.info("pending updates: %r", pending_updates)
 
             for master_key, self_signing_key in pending_updates:
-                new_device_ids = (
-                    await device_list_updater.process_cross_signing_key_update(
-                        user_id,
-                        master_key,
-                        self_signing_key,
-                    )
+                new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
+                    user_id,
+                    master_key,
+                    self_signing_key,
                 )
                 device_ids = device_ids + new_device_ids
 
-            await device_handler.notify_device_update(user_id, device_ids)
+            await self._device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ca1c7a1866..6307fa9c5d 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
 )
 from synapse.appservice import ApplicationService
 from synapse.config.server import is_threepid_reserved
+from synapse.handlers.device import DeviceHandler
 from synapse.http.servlet import assert_params_in_dict
 from synapse.replication.http.login import RegisterDeviceReplicationServlet
 from synapse.replication.http.register import (
@@ -841,6 +842,9 @@ class RegistrationHandler:
         refresh_token = None
         refresh_token_id = None
 
+        # This can only run on the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
+
         registered_device_id = await self.device_handler.check_device_registered(
             user_id,
             device_id,
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 73861bbd40..bd9d0bb34b 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -15,6 +15,7 @@ import logging
 from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import Codes, StoreError, SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.types import Requester
 
 if TYPE_CHECKING:
@@ -29,7 +30,10 @@ class SetPasswordHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self._auth_handler = hs.get_auth_handler()
-        self._device_handler = hs.get_device_handler()
+        # This can only be instantiated on the main process.
+        device_handler = hs.get_device_handler()
+        assert isinstance(device_handler, DeviceHandler)
+        self._device_handler = device_handler
 
     async def set_password(
         self,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 749d7e93b0..e1c0bff1b2 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -37,6 +37,7 @@ from twisted.web.server import Request
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
 from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.device import DeviceHandler
 from synapse.handlers.register import init_counters_for_auth_provider
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
@@ -1035,6 +1036,8 @@ class SsoHandler:
     ) -> None:
         """Revoke any devices and in-flight logins tied to a provider session.
 
+        Can only be called from the main process.
+
         Args:
             auth_provider_id: A unique identifier for this SSO provider, e.g.
                 "oidc" or "saml".
@@ -1042,6 +1045,12 @@ class SsoHandler:
             expected_user_id: The user we're expecting to logout. If set, it will ignore
                 sessions belonging to other users and log an error.
         """
+
+        # It is expected that this is the main process.
+        assert isinstance(
+            self._device_handler, DeviceHandler
+        ), "revoking SSO sessions can only be called on the main process"
+
         # Invalidate any running user-mapping sessions
         to_delete = []
         for session_id, session in self._username_mapping_sessions.items():