summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py50
1 files changed, 49 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9ccc66e589..d5a4a661cd 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return {d["device_id"]: d for d in devices}
 
+    async def get_devices_by_auth_provider_session_id(
+        self, auth_provider_id: str, auth_provider_session_id: str
+    ) -> List[Dict[str, Any]]:
+        """Retrieve the list of devices associated with a SSO IdP session ID.
+
+        Args:
+            auth_provider_id: The SSO IdP ID as defined in the server config
+            auth_provider_session_id: The session ID within the IdP
+        Returns:
+            A list of dicts containing the device_id and the user_id of each device
+        """
+        return await self.db_pool.simple_select_list(
+            table="device_auth_providers",
+            keyvalues={
+                "auth_provider_id": auth_provider_id,
+                "auth_provider_session_id": auth_provider_session_id,
+            },
+            retcols=("user_id", "device_id"),
+            desc="get_devices_by_auth_provider_session_id",
+        )
+
     @trace
     async def get_device_updates_by_remote(
         self, destination: str, from_stream_id: int, limit: int
@@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def store_device(
-        self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
+        self,
+        user_id: str,
+        device_id: str,
+        initial_device_display_name: Optional[str],
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> bool:
         """Ensure the given device is known; add it to the store if not
 
@@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             device_id: id of device
             initial_device_display_name: initial displayname of the device.
                 Ignored if device exists.
+            auth_provider_id: The SSO IdP the user used, if any.
+            auth_provider_session_id: The session ID (sid) got from a OIDC login.
 
         Returns:
             Whether the device was inserted or an existing device existed with that ID.
@@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 if hidden:
                     raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
 
+            if auth_provider_id and auth_provider_session_id:
+                await self.db_pool.simple_insert(
+                    "device_auth_providers",
+                    values={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "auth_provider_id": auth_provider_id,
+                        "auth_provider_session_id": auth_provider_session_id,
+                    },
+                    desc="store_device_auth_provider",
+                )
+
             self.device_id_exists_cache.set(key, True)
             return inserted
         except StoreError:
@@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 keyvalues={"user_id": user_id},
             )
 
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="device_auth_providers",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id},
+            )
+
         await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))