summary refs log tree commit diff
path: root/synapse/storage/databases/main/client_ips.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/client_ips.py')
-rw-r--r--synapse/storage/databases/main/client_ips.py46
1 files changed, 26 insertions, 20 deletions
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 7da47c3dd7..8be1511859 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -15,6 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
 
+import attr
 from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -42,7 +43,8 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
-class DeviceLastConnectionInfo(TypedDict):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DeviceLastConnectionInfo:
     """Metadata for the last connection seen for a user and device combination"""
 
     # These types must match the columns in the `devices` table
@@ -499,24 +501,29 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
             device_id: If None fetches all devices for the user
 
         Returns:
-            A dictionary mapping a tuple of (user_id, device_id) to dicts, with
-            keys giving the column names from the devices table.
+            A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
         """
 
         keyvalues = {"user_id": user_id}
         if device_id is not None:
             keyvalues["device_id"] = device_id
 
-        res = cast(
-            List[DeviceLastConnectionInfo],
-            await self.db_pool.simple_select_list(
-                table="devices",
-                keyvalues=keyvalues,
-                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
-            ),
+        res = await self.db_pool.simple_select_list(
+            table="devices",
+            keyvalues=keyvalues,
+            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
         )
 
-        return {(d["user_id"], d["device_id"]): d for d in res}
+        return {
+            (d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
+                user_id=d["user_id"],
+                device_id=d["device_id"],
+                ip=d["ip"],
+                user_agent=d["user_agent"],
+                last_seen=d["last_seen"],
+            )
+            for d in res
+        }
 
     async def _get_user_ip_and_agents_from_database(
         self, user: UserID, since_ts: int = 0
@@ -683,8 +690,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
             device_id: If None fetches all devices for the user
 
         Returns:
-            A dictionary mapping a tuple of (user_id, device_id) to dicts, with
-            keys giving the column names from the devices table.
+            A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
         """
         ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
 
@@ -705,13 +711,13 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
                     continue
 
                 if not device_id or did == device_id:
-                    ret[(user_id, did)] = {
-                        "user_id": user_id,
-                        "ip": ip,
-                        "user_agent": user_agent,
-                        "device_id": did,
-                        "last_seen": last_seen,
-                    }
+                    ret[(user_id, did)] = DeviceLastConnectionInfo(
+                        user_id=user_id,
+                        ip=ip,
+                        user_agent=user_agent,
+                        device_id=did,
+                        last_seen=last_seen,
+                    )
         return ret
 
     async def get_user_ip_and_agents(