summary refs log tree commit diff
path: root/synapse/handlers/device.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/device.py')
-rw-r--r--synapse/handlers/device.py171
1 files changed, 135 insertions, 36 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b9d9098104..debb1b4f29 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
 # Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -29,7 +29,10 @@ from synapse.api.errors import (
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import (
+    Collection,
+    JsonDict,
     StreamToken,
+    UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
@@ -41,13 +44,16 @@ from synapse.util.retryutils import NotRetryingDestination
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 MAX_DEVICE_DISPLAY_NAME_LEN = 100
 
 
 class DeviceWorkerHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.hs = hs
@@ -105,7 +111,9 @@ class DeviceWorkerHandler(BaseHandler):
 
     @trace
     @measure_func("device.get_user_ids_changed")
-    async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
+    async def get_user_ids_changed(
+        self, user_id: str, from_token: StreamToken
+    ) -> JsonDict:
         """Get list of users that have had the devices updated, or have newly
         joined a room, that `user_id` may be interested in.
         """
@@ -221,8 +229,8 @@ class DeviceWorkerHandler(BaseHandler):
             possibly_joined = possibly_changed & users_who_share_room
             possibly_left = (possibly_changed | possibly_left) - users_who_share_room
         else:
-            possibly_joined = []
-            possibly_left = []
+            possibly_joined = set()
+            possibly_left = set()
 
         result = {"changed": list(possibly_joined), "left": list(possibly_left)}
 
@@ -230,7 +238,7 @@ class DeviceWorkerHandler(BaseHandler):
 
         return result
 
-    async def on_federation_query_user_devices(self, user_id):
+    async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
         stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
             user_id
         )
@@ -249,7 +257,7 @@ class DeviceWorkerHandler(BaseHandler):
 
 
 class DeviceHandler(DeviceWorkerHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.federation_sender = hs.get_federation_sender()
@@ -264,7 +272,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
-    def _check_device_name_length(self, name: str):
+    def _check_device_name_length(self, name: Optional[str]):
         """
         Checks whether a device name is longer than the maximum allowed length.
 
@@ -283,8 +291,11 @@ class DeviceHandler(DeviceWorkerHandler):
             )
 
     async def check_device_registered(
-        self, user_id, device_id, initial_device_display_name=None
-    ):
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_device_display_name: Optional[str] = None,
+    ) -> str:
         """
         If the given device has not been registered, register it with the
         supplied display name.
@@ -292,12 +303,11 @@ class DeviceHandler(DeviceWorkerHandler):
         If no device_id is supplied, we make one up.
 
         Args:
-            user_id (str):  @user:id
-            device_id (str | None): device id supplied by client
-            initial_device_display_name (str | None): device display name from
-                 client
+            user_id:  @user:id
+            device_id: device id supplied by client
+            initial_device_display_name: device display name from client
         Returns:
-            str: device id (generated if none was supplied)
+            device id (generated if none was supplied)
         """
 
         self._check_device_name_length(initial_device_display_name)
@@ -316,15 +326,15 @@ class DeviceHandler(DeviceWorkerHandler):
         # times in case of a clash.
         attempts = 0
         while attempts < 5:
-            device_id = stringutils.random_string(10).upper()
+            new_device_id = stringutils.random_string(10).upper()
             new_device = await self.store.store_device(
                 user_id=user_id,
-                device_id=device_id,
+                device_id=new_device_id,
                 initial_device_display_name=initial_device_display_name,
             )
             if new_device:
-                await self.notify_device_update(user_id, [device_id])
-                return device_id
+                await self.notify_device_update(user_id, [new_device_id])
+                return new_device_id
             attempts += 1
 
         raise errors.StoreError(500, "Couldn't generate a device ID.")
@@ -433,7 +443,9 @@ class DeviceHandler(DeviceWorkerHandler):
 
     @trace
     @measure_func("notify_device_update")
-    async def notify_device_update(self, user_id, device_ids):
+    async def notify_device_update(
+        self, user_id: str, device_ids: Collection[str]
+    ) -> None:
         """Notify that a user's device(s) has changed. Pokes the notifier, and
         remote servers if the user is local.
         """
@@ -445,7 +457,7 @@ class DeviceHandler(DeviceWorkerHandler):
             user_id
         )
 
-        hosts = set()
+        hosts = set()  # type: Set[str]
         if self.hs.is_mine_id(user_id):
             hosts.update(get_domain_from_id(u) for u in users_who_share_room)
             hosts.discard(self.server_name)
@@ -497,7 +509,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
         self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
 
-    async def user_left_room(self, user, room_id):
+    async def user_left_room(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
         room_ids = await self.store.get_rooms_for_user(user_id)
         if not room_ids:
@@ -505,8 +517,89 @@ class DeviceHandler(DeviceWorkerHandler):
             # receive device updates. Mark this in DB.
             await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
 
+    async def store_dehydrated_device(
+        self,
+        user_id: str,
+        device_data: JsonDict,
+        initial_device_display_name: Optional[str] = None,
+    ) -> str:
+        """Store a dehydrated device for a user.  If the user had a previous
+        dehydrated device, it is removed.
+
+        Args:
+            user_id: the user that we are storing the device for
+            device_data: the dehydrated device information
+            initial_device_display_name: The display name to use for the device
+        Returns:
+            device id of the dehydrated device
+        """
+        device_id = await self.check_device_registered(
+            user_id, None, initial_device_display_name,
+        )
+        old_device_id = await self.store.store_dehydrated_device(
+            user_id, device_id, device_data
+        )
+        if old_device_id is not None:
+            await self.delete_device(user_id, old_device_id)
+        return device_id
+
+    async def get_dehydrated_device(
+        self, user_id: str
+    ) -> Optional[Tuple[str, JsonDict]]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        return await self.store.get_dehydrated_device(user_id)
+
+    async def rehydrate_device(
+        self, user_id: str, access_token: str, device_id: str
+    ) -> dict:
+        """Process a rehydration request from the user.
+
+        Args:
+            user_id: the user who is rehydrating the device
+            access_token: the access token used for the request
+            device_id: the ID of the device that will be rehydrated
+        Returns:
+            a dict containing {"success": True}
+        """
+        success = await self.store.remove_dehydrated_device(user_id, device_id)
+
+        if not success:
+            raise errors.NotFoundError()
+
+        # If the dehydrated device was successfully deleted (the device ID
+        # matched the stored dehydrated device), then modify the access
+        # token to use the dehydrated device's ID and copy the old device
+        # display name to the dehydrated device, and destroy the old device
+        # ID
+        old_device_id = await self.store.set_device_for_access_token(
+            access_token, device_id
+        )
+        old_device = await self.store.get_device(user_id, old_device_id)
+        await self.store.update_device(user_id, device_id, old_device["display_name"])
+        # can't call self.delete_device because that will clobber the
+        # access token so call the storage layer directly
+        await self.store.delete_device(user_id, old_device_id)
+        await self.store.delete_e2e_keys_by_device(
+            user_id=user_id, device_id=old_device_id
+        )
 
-def _update_device_from_client_ips(device, client_ips):
+        # tell everyone that the old device is gone and that the dehydrated
+        # device has a new display name
+        await self.notify_device_update(user_id, [old_device_id, device_id])
+
+        return {"success": True}
+
+
+def _update_device_from_client_ips(
+    device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
 
@@ -514,7 +607,7 @@ def _update_device_from_client_ips(device, client_ips):
 class DeviceListUpdater:
     "Handles incoming device list updates from federation and updates the DB"
 
-    def __init__(self, hs, device_handler):
+    def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
@@ -523,7 +616,9 @@ class DeviceListUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_device_list")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}
+        self._pending_updates = (
+            {}
+        )  # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
 
         # Recently seen stream ids. We don't bother keeping these in the DB,
         # but they're useful to have them about to reduce the number of spurious
@@ -546,7 +641,9 @@ class DeviceListUpdater:
         )
 
     @trace
-    async def incoming_device_list_update(self, origin, edu_content):
+    async def incoming_device_list_update(
+        self, origin: str, edu_content: JsonDict
+    ) -> None:
         """Called on incoming device list update from federation. Responsible
         for parsing the EDU and adding to pending updates list.
         """
@@ -607,7 +704,7 @@ class DeviceListUpdater:
         await self._handle_device_updates(user_id)
 
     @measure_func("_incoming_device_list_update")
-    async def _handle_device_updates(self, user_id):
+    async def _handle_device_updates(self, user_id: str) -> None:
         "Actually handle pending updates."
 
         with (await self._remote_edu_linearizer.queue(user_id)):
@@ -655,7 +752,9 @@ class DeviceListUpdater:
                     stream_id for _, stream_id, _, _ in pending_updates
                 )
 
-    async def _need_to_do_resync(self, user_id, updates):
+    async def _need_to_do_resync(
+        self, user_id: str, updates: Iterable[Tuple[str, str, Iterable[str], JsonDict]]
+    ) -> bool:
         """Given a list of updates for a user figure out if we need to do a full
         resync, or whether we have enough data that we can just apply the delta.
         """
@@ -686,7 +785,7 @@ class DeviceListUpdater:
         return False
 
     @trace
-    async def _maybe_retry_device_resync(self):
+    async def _maybe_retry_device_resync(self) -> None:
         """Retry to resync device lists that are out of sync, except if another retry is
         in progress.
         """
@@ -729,7 +828,7 @@ class DeviceListUpdater:
 
     async def user_device_resync(
         self, user_id: str, mark_failed_as_stale: bool = True
-    ) -> Optional[dict]:
+    ) -> Optional[JsonDict]:
         """Fetches all devices for a user and updates the device cache with them.
 
         Args:
@@ -753,7 +852,7 @@ class DeviceListUpdater:
                 # it later.
                 await self.store.mark_remote_user_device_cache_as_stale(user_id)
 
-            return
+            return None
         except (RequestSendFailed, HttpResponseException) as e:
             logger.warning(
                 "Failed to handle device list update for %s: %s", user_id, e,
@@ -770,12 +869,12 @@ class DeviceListUpdater:
             # next time we get a device list update for this user_id.
             # This makes it more likely that the device lists will
             # eventually become consistent.
-            return
+            return None
         except FederationDeniedError as e:
             set_tag("error", True)
             log_kv({"reason": "FederationDeniedError"})
             logger.info(e)
-            return
+            return None
         except Exception as e:
             set_tag("error", True)
             log_kv(
@@ -788,7 +887,7 @@ class DeviceListUpdater:
                 # it later.
                 await self.store.mark_remote_user_device_cache_as_stale(user_id)
 
-            return
+            return None
         log_kv({"result": result})
         stream_id = result["stream_id"]
         devices = result["devices"]
@@ -849,7 +948,7 @@ class DeviceListUpdater:
         user_id: str,
         master_key: Optional[Dict[str, Any]],
         self_signing_key: Optional[Dict[str, Any]],
-    ) -> list:
+    ) -> List[str]:
         """Process the given new master and self-signing key for the given remote user.
 
         Args: