summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-10-17 08:47:42 -0400
committerGitHub <noreply@github.com>2023-10-17 12:47:42 +0000
commit6ad1f9eac2c5ffc496597acbc5728482441c64c7 (patch)
tree48292d27b2e1e0e1867bed22f580b5b53273c851
parentFix a bug where servers could be marked as up when they were failing (#16506) (diff)
downloadsynapse-6ad1f9eac2c5ffc496597acbc5728482441c64c7.tar.xz
Convert DeviceLastConnectionInfo to attrs. (#16507)
To improve type safety & memory usage.
-rw-r--r--changelog.d/16507.misc1
-rw-r--r--synapse/handlers/device.py23
-rw-r--r--synapse/storage/databases/main/client_ips.py46
-rw-r--r--tests/storage/test_client_ips.py137
4 files changed, 104 insertions, 103 deletions
diff --git a/changelog.d/16507.misc b/changelog.d/16507.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16507.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 50df4f2b06..544bc7c13d 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,17 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Iterable,
-    List,
-    Mapping,
-    Optional,
-    Set,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Set, Tuple
 
 from synapse.api import errors
 from synapse.api.constants import EduTypes, EventTypes
@@ -41,6 +31,7 @@ from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
 )
+from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
 from synapse.types import (
     JsonDict,
     JsonMapping,
@@ -1008,14 +999,14 @@ class DeviceHandler(DeviceWorkerHandler):
 
 
 def _update_device_from_client_ips(
-    device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
+    device: JsonDict, client_ips: Mapping[Tuple[str, str], DeviceLastConnectionInfo]
 ) -> None:
-    ip = client_ips.get((device["user_id"], device["device_id"]), {})
+    ip = client_ips.get((device["user_id"], device["device_id"]))
     device.update(
         {
-            "last_seen_user_agent": ip.get("user_agent"),
-            "last_seen_ts": ip.get("last_seen"),
-            "last_seen_ip": ip.get("ip"),
+            "last_seen_user_agent": ip.user_agent if ip else None,
+            "last_seen_ts": ip.last_seen if ip else None,
+            "last_seen_ip": ip.ip if ip else None,
         }
     )
 
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 7da47c3dd7..8be1511859 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -15,6 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
 
+import attr
 from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -42,7 +43,8 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
-class DeviceLastConnectionInfo(TypedDict):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DeviceLastConnectionInfo:
     """Metadata for the last connection seen for a user and device combination"""
 
     # These types must match the columns in the `devices` table
@@ -499,24 +501,29 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
             device_id: If None fetches all devices for the user
 
         Returns:
-            A dictionary mapping a tuple of (user_id, device_id) to dicts, with
-            keys giving the column names from the devices table.
+            A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
         """
 
         keyvalues = {"user_id": user_id}
         if device_id is not None:
             keyvalues["device_id"] = device_id
 
-        res = cast(
-            List[DeviceLastConnectionInfo],
-            await self.db_pool.simple_select_list(
-                table="devices",
-                keyvalues=keyvalues,
-                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
-            ),
+        res = await self.db_pool.simple_select_list(
+            table="devices",
+            keyvalues=keyvalues,
+            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
         )
 
-        return {(d["user_id"], d["device_id"]): d for d in res}
+        return {
+            (d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
+                user_id=d["user_id"],
+                device_id=d["device_id"],
+                ip=d["ip"],
+                user_agent=d["user_agent"],
+                last_seen=d["last_seen"],
+            )
+            for d in res
+        }
 
     async def _get_user_ip_and_agents_from_database(
         self, user: UserID, since_ts: int = 0
@@ -683,8 +690,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
             device_id: If None fetches all devices for the user
 
         Returns:
-            A dictionary mapping a tuple of (user_id, device_id) to dicts, with
-            keys giving the column names from the devices table.
+            A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
         """
         ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
 
@@ -705,13 +711,13 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
                     continue
 
                 if not device_id or did == device_id:
-                    ret[(user_id, did)] = {
-                        "user_id": user_id,
-                        "ip": ip,
-                        "user_agent": user_agent,
-                        "device_id": did,
-                        "last_seen": last_seen,
-                    }
+                    ret[(user_id, did)] = DeviceLastConnectionInfo(
+                        user_id=user_id,
+                        ip=ip,
+                        user_agent=user_agent,
+                        device_id=did,
+                        last_seen=last_seen,
+                    )
         return ret
 
     async def get_user_ip_and_agents(
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 6b9692c486..0c054a598f 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -24,7 +24,10 @@ import synapse.rest.admin
 from synapse.http.site import XForwardedForRequest
 from synapse.rest.client import login
 from synapse.server import HomeServer
-from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.databases.main.client_ips import (
+    LAST_SEEN_GRANULARITY,
+    DeviceLastConnectionInfo,
+)
 from synapse.types import UserID
 from synapse.util import Clock
 
@@ -65,15 +68,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result[(user_id, device_id)]
-        self.assertLessEqual(
-            {
-                "user_id": user_id,
-                "device_id": device_id,
-                "ip": "ip",
-                "user_agent": "user_agent",
-                "last_seen": 12345678000,
-            }.items(),
-            r.items(),
+        self.assertEqual(
+            DeviceLastConnectionInfo(
+                user_id=user_id,
+                device_id=device_id,
+                ip="ip",
+                user_agent="user_agent",
+                last_seen=12345678000,
+            ),
+            r,
         )
 
     def test_insert_new_client_ip_none_device_id(self) -> None:
@@ -201,13 +204,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.assertEqual(
             result,
             {
-                (user_id, device_id): {
-                    "user_id": user_id,
-                    "device_id": device_id,
-                    "ip": "ip",
-                    "user_agent": "user_agent",
-                    "last_seen": 12345678000,
-                },
+                (user_id, device_id): DeviceLastConnectionInfo(
+                    user_id=user_id,
+                    device_id=device_id,
+                    ip="ip",
+                    user_agent="user_agent",
+                    last_seen=12345678000,
+                ),
             },
         )
 
@@ -292,20 +295,20 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.assertEqual(
             result,
             {
-                (user_id, device_id_1): {
-                    "user_id": user_id,
-                    "device_id": device_id_1,
-                    "ip": "ip_1",
-                    "user_agent": "user_agent_1",
-                    "last_seen": 12345678000,
-                },
-                (user_id, device_id_2): {
-                    "user_id": user_id,
-                    "device_id": device_id_2,
-                    "ip": "ip_2",
-                    "user_agent": "user_agent_3",
-                    "last_seen": 12345688000 + LAST_SEEN_GRANULARITY,
-                },
+                (user_id, device_id_1): DeviceLastConnectionInfo(
+                    user_id=user_id,
+                    device_id=device_id_1,
+                    ip="ip_1",
+                    user_agent="user_agent_1",
+                    last_seen=12345678000,
+                ),
+                (user_id, device_id_2): DeviceLastConnectionInfo(
+                    user_id=user_id,
+                    device_id=device_id_2,
+                    ip="ip_2",
+                    user_agent="user_agent_3",
+                    last_seen=12345688000 + LAST_SEEN_GRANULARITY,
+                ),
             },
         )
 
@@ -526,15 +529,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result[(user_id, device_id)]
-        self.assertLessEqual(
-            {
-                "user_id": user_id,
-                "device_id": device_id,
-                "ip": None,
-                "user_agent": None,
-                "last_seen": None,
-            }.items(),
-            r.items(),
+        self.assertEqual(
+            DeviceLastConnectionInfo(
+                user_id=user_id,
+                device_id=device_id,
+                ip=None,
+                user_agent=None,
+                last_seen=None,
+            ),
+            r,
         )
 
         # Register the background update to run again.
@@ -561,15 +564,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result[(user_id, device_id)]
-        self.assertLessEqual(
-            {
-                "user_id": user_id,
-                "device_id": device_id,
-                "ip": "ip",
-                "user_agent": "user_agent",
-                "last_seen": 0,
-            }.items(),
-            r.items(),
+        self.assertEqual(
+            DeviceLastConnectionInfo(
+                user_id=user_id,
+                device_id=device_id,
+                ip="ip",
+                user_agent="user_agent",
+                last_seen=0,
+            ),
+            r,
         )
 
     def test_old_user_ips_pruned(self) -> None:
@@ -640,15 +643,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result2[(user_id, device_id)]
-        self.assertLessEqual(
-            {
-                "user_id": user_id,
-                "device_id": device_id,
-                "ip": "ip",
-                "user_agent": "user_agent",
-                "last_seen": 0,
-            }.items(),
-            r.items(),
+        self.assertEqual(
+            DeviceLastConnectionInfo(
+                user_id=user_id,
+                device_id=device_id,
+                ip="ip",
+                user_agent="user_agent",
+                last_seen=0,
+            ),
+            r,
         )
 
     def test_invalid_user_agents_are_ignored(self) -> None:
@@ -777,13 +780,13 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
             self.store.get_last_client_ip_by_device(self.user_id, device_id)
         )
         r = result[(self.user_id, device_id)]
-        self.assertLessEqual(
-            {
-                "user_id": self.user_id,
-                "device_id": device_id,
-                "ip": expected_ip,
-                "user_agent": "Mozzila pizza",
-                "last_seen": 123456100,
-            }.items(),
-            r.items(),
+        self.assertEqual(
+            DeviceLastConnectionInfo(
+                user_id=self.user_id,
+                device_id=device_id,
+                ip=expected_ip,
+                user_agent="Mozzila pizza",
+                last_seen=123456100,
+            ),
+            r,
         )