summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11179.misc1
-rw-r--r--tests/storage/test_client_ips.py206
2 files changed, 207 insertions, 0 deletions
diff --git a/changelog.d/11179.misc b/changelog.d/11179.misc
new file mode 100644
index 0000000000..aded2e8367
--- /dev/null
+++ b/changelog.d/11179.misc
@@ -0,0 +1 @@
+Add tests to check that `ClientIpStore.get_last_client_ip_by_device` and `get_user_ip_and_agents` combine database and in-memory data correctly.
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 0e4013ebea..c8ac67e35b 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -20,6 +20,7 @@ from parameterized import parameterized
 import synapse.rest.admin
 from synapse.http.site import XForwardedForRequest
 from synapse.rest.client import login
+from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.types import UserID
 
 from tests import unittest
@@ -171,6 +172,27 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         if after_persisting:
             # Trigger the storage loop
             self.reactor.advance(10)
+        else:
+            # Check that the new IP and user agent has not been stored yet
+            db_result = self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="devices",
+                    keyvalues={},
+                    retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+                ),
+            )
+            self.assertEqual(
+                db_result,
+                [
+                    {
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "ip": None,
+                        "user_agent": None,
+                        "last_seen": None,
+                    },
+                ],
+            )
 
         result = self.get_success(
             self.store.get_last_client_ip_by_device(user_id, device_id)
@@ -189,6 +211,104 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             },
         )
 
+    def test_get_last_client_ip_by_device_combined_data(self):
+        """Test that `get_last_client_ip_by_device` combines persisted and unpersisted
+        data together correctly
+        """
+        self.reactor.advance(12345678)
+
+        user_id = "@user:id"
+        device_id_1 = "MY_DEVICE_1"
+        device_id_2 = "MY_DEVICE_2"
+
+        # Insert user IPs
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id_1,
+                "display name",
+            )
+        )
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id_2,
+                "display name",
+            )
+        )
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token_1", "ip_1", "user_agent_1", device_id_1
+            )
+        )
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token_2", "ip_2", "user_agent_2", device_id_2
+            )
+        )
+
+        # Trigger the storage loop and wait for the rate limiting period to be over
+        self.reactor.advance(10 + LAST_SEEN_GRANULARITY / 1000)
+
+        # Update the user agent for the second device, without running the storage loop
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token_2", "ip_2", "user_agent_3", device_id_2
+            )
+        )
+
+        # Check that the new IP and user agent has not been stored yet
+        db_result = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="devices",
+                keyvalues={},
+                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+            ),
+        )
+        self.assertCountEqual(
+            db_result,
+            [
+                {
+                    "user_id": user_id,
+                    "device_id": device_id_1,
+                    "ip": "ip_1",
+                    "user_agent": "user_agent_1",
+                    "last_seen": 12345678000,
+                },
+                {
+                    "user_id": user_id,
+                    "device_id": device_id_2,
+                    "ip": "ip_2",
+                    "user_agent": "user_agent_2",
+                    "last_seen": 12345678000,
+                },
+            ],
+        )
+
+        # Check that data from the database and memory are combined together correctly
+        result = self.get_success(
+            self.store.get_last_client_ip_by_device(user_id, None)
+        )
+        self.assertEqual(
+            result,
+            {
+                (user_id, device_id_1): {
+                    "user_id": user_id,
+                    "device_id": device_id_1,
+                    "ip": "ip_1",
+                    "user_agent": "user_agent_1",
+                    "last_seen": 12345678000,
+                },
+                (user_id, device_id_2): {
+                    "user_id": user_id,
+                    "device_id": device_id_2,
+                    "ip": "ip_2",
+                    "user_agent": "user_agent_3",
+                    "last_seen": 12345688000 + LAST_SEEN_GRANULARITY,
+                },
+            },
+        )
+
     @parameterized.expand([(False,), (True,)])
     def test_get_user_ip_and_agents(self, after_persisting: bool):
         """Test `get_user_ip_and_agents` for persisted and unpersisted data"""
@@ -207,6 +327,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         if after_persisting:
             # Trigger the storage loop
             self.reactor.advance(10)
+        else:
+            # Check that the new IP and user agent has not been stored yet
+            db_result = self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="user_ips",
+                    keyvalues={},
+                    retcols=("access_token", "ip", "user_agent", "last_seen"),
+                ),
+            )
+            self.assertEqual(db_result, [])
 
         self.assertEqual(
             self.get_success(self.store.get_user_ip_and_agents(user)),
@@ -220,6 +350,82 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             ],
         )
 
+    def test_get_user_ip_and_agents_combined_data(self):
+        """Test that `get_user_ip_and_agents` combines persisted and unpersisted data
+        together correctly
+        """
+        self.reactor.advance(12345678)
+
+        user_id = "@user:id"
+        user = UserID.from_string(user_id)
+
+        # Insert user IPs
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip_1", "user_agent_1", "MY_DEVICE_1"
+            )
+        )
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip_2", "user_agent_2", "MY_DEVICE_2"
+            )
+        )
+
+        # Trigger the storage loop and wait for the rate limiting period to be over
+        self.reactor.advance(10 + LAST_SEEN_GRANULARITY / 1000)
+
+        # Update the user agent for the second device, without running the storage loop
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip_2", "user_agent_3", "MY_DEVICE_2"
+            )
+        )
+
+        # Check that the new IP and user agent has not been stored yet
+        db_result = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="user_ips",
+                keyvalues={},
+                retcols=("access_token", "ip", "user_agent", "last_seen"),
+            ),
+        )
+        self.assertEqual(
+            db_result,
+            [
+                {
+                    "access_token": "access_token",
+                    "ip": "ip_1",
+                    "user_agent": "user_agent_1",
+                    "last_seen": 12345678000,
+                },
+                {
+                    "access_token": "access_token",
+                    "ip": "ip_2",
+                    "user_agent": "user_agent_2",
+                    "last_seen": 12345678000,
+                },
+            ],
+        )
+
+        # Check that data from the database and memory are combined together correctly
+        self.assertCountEqual(
+            self.get_success(self.store.get_user_ip_and_agents(user)),
+            [
+                {
+                    "access_token": "access_token",
+                    "ip": "ip_1",
+                    "user_agent": "user_agent_1",
+                    "last_seen": 12345678000,
+                },
+                {
+                    "access_token": "access_token",
+                    "ip": "ip_2",
+                    "user_agent": "user_agent_3",
+                    "last_seen": 12345688000 + LAST_SEEN_GRANULARITY,
+                },
+            ],
+        )
+
     @override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
     def test_disabled_monthly_active_user(self):
         user_id = "@user:server"