summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/databases/main/test_deviceinbox.py164
-rw-r--r--tests/storage/test_client_ips.py206
-rw-r--r--tests/storage/test_rollback_worker.py69
3 files changed, 439 insertions, 0 deletions
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
new file mode 100644

index 0000000000..4b67bd15b7 --- /dev/null +++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -0,0 +1,164 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client import devices + +from tests.unittest import HomeserverTestCase + + +class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): + + servlets = [ + admin.register_servlets, + devices.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.user_id = self.register_user("foo", "pass") + + def test_background_remove_deleted_devices_from_device_inbox(self): + """Test that the background task to delete old device_inboxes works properly.""" + + # create a valid device + self.get_success( + self.store.store_device(self.user_id, "cur_device", "display_name") + ) + + # Add device_inbox to devices + self.get_success( + self.store.db_pool.simple_insert( + "device_inbox", + { + "user_id": self.user_id, + "device_id": "cur_device", + "stream_id": 1, + "message_json": "{}", + }, + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + "device_inbox", + { + "user_id": self.user_id, + "device_id": "old_device", + "stream_id": 2, + "message_json": "{}", + }, + ) + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": "remove_deleted_devices_from_device_inbox", + "progress_json": "{}", + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + self.wait_for_background_updates() + + # Make sure the background task deleted old device_inbox + res = self.get_success( + self.store.db_pool.simple_select_onecol( + table="device_inbox", + keyvalues={}, + retcol="device_id", + desc="get_device_id_from_device_inbox", + ) + ) + self.assertEqual(1, len(res)) + self.assertEqual(res[0], "cur_device") + + def test_background_remove_hidden_devices_from_device_inbox(self): + """Test that the background task to delete hidden devices + from device_inboxes works properly.""" + + # create a valid device + self.get_success( + self.store.store_device(self.user_id, "cur_device", "display_name") + ) + + # create a hidden device + self.get_success( + self.store.db_pool.simple_insert( + "devices", + values={ + "user_id": self.user_id, + "device_id": "hidden_device", + "display_name": "hidden_display_name", + "hidden": True, + }, + ) + ) + + # Add device_inbox to devices + self.get_success( + self.store.db_pool.simple_insert( + "device_inbox", + { + "user_id": self.user_id, + "device_id": "cur_device", + "stream_id": 1, + "message_json": "{}", + }, + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + "device_inbox", + { + "user_id": self.user_id, + "device_id": "hidden_device", + "stream_id": 2, + "message_json": "{}", + }, + ) + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": "remove_hidden_devices_from_device_inbox", + "progress_json": "{}", + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + self.wait_for_background_updates() + + # Make sure the background task deleted hidden devices from device_inbox + res = self.get_success( + self.store.db_pool.simple_select_onecol( + table="device_inbox", + keyvalues={}, + retcol="device_id", + desc="get_device_id_from_device_inbox", + ) + ) + self.assertEqual(1, len(res)) + self.assertEqual(res[0], "cur_device") diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 0e4013ebea..c8ac67e35b 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -20,6 +20,7 @@ from parameterized import parameterized import synapse.rest.admin from synapse.http.site import XForwardedForRequest from synapse.rest.client import login +from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.types import UserID from tests import unittest @@ -171,6 +172,27 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): if after_persisting: # Trigger the storage loop self.reactor.advance(10) + else: + # Check that the new IP and user agent has not been stored yet + db_result = self.get_success( + self.store.db_pool.simple_select_list( + table="devices", + keyvalues={}, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ), + ) + self.assertEqual( + db_result, + [ + { + "user_id": user_id, + "device_id": device_id, + "ip": None, + "user_agent": None, + "last_seen": None, + }, + ], + ) result = self.get_success( self.store.get_last_client_ip_by_device(user_id, device_id) @@ -189,6 +211,104 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): }, ) + def test_get_last_client_ip_by_device_combined_data(self): + """Test that `get_last_client_ip_by_device` combines persisted and unpersisted + data together correctly + """ + self.reactor.advance(12345678) + + user_id = "@user:id" + device_id_1 = "MY_DEVICE_1" + device_id_2 = "MY_DEVICE_2" + + # Insert user IPs + self.get_success( + self.store.store_device( + user_id, + device_id_1, + "display name", + ) + ) + self.get_success( + self.store.store_device( + user_id, + device_id_2, + "display name", + ) + ) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token_1", "ip_1", "user_agent_1", device_id_1 + ) + ) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token_2", "ip_2", "user_agent_2", device_id_2 + ) + ) + + # Trigger the storage loop and wait for the rate limiting period to be over + self.reactor.advance(10 + LAST_SEEN_GRANULARITY / 1000) + + # Update the user agent for the second device, without running the storage loop + self.get_success( + self.store.insert_client_ip( + user_id, "access_token_2", "ip_2", "user_agent_3", device_id_2 + ) + ) + + # Check that the new IP and user agent has not been stored yet + db_result = self.get_success( + self.store.db_pool.simple_select_list( + table="devices", + keyvalues={}, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ), + ) + self.assertCountEqual( + db_result, + [ + { + "user_id": user_id, + "device_id": device_id_1, + "ip": "ip_1", + "user_agent": "user_agent_1", + "last_seen": 12345678000, + }, + { + "user_id": user_id, + "device_id": device_id_2, + "ip": "ip_2", + "user_agent": "user_agent_2", + "last_seen": 12345678000, + }, + ], + ) + + # Check that data from the database and memory are combined together correctly + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, None) + ) + self.assertEqual( + result, + { + (user_id, device_id_1): { + "user_id": user_id, + "device_id": device_id_1, + "ip": "ip_1", + "user_agent": "user_agent_1", + "last_seen": 12345678000, + }, + (user_id, device_id_2): { + "user_id": user_id, + "device_id": device_id_2, + "ip": "ip_2", + "user_agent": "user_agent_3", + "last_seen": 12345688000 + LAST_SEEN_GRANULARITY, + }, + }, + ) + @parameterized.expand([(False,), (True,)]) def test_get_user_ip_and_agents(self, after_persisting: bool): """Test `get_user_ip_and_agents` for persisted and unpersisted data""" @@ -207,6 +327,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): if after_persisting: # Trigger the storage loop self.reactor.advance(10) + else: + # Check that the new IP and user agent has not been stored yet + db_result = self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={}, + retcols=("access_token", "ip", "user_agent", "last_seen"), + ), + ) + self.assertEqual(db_result, []) self.assertEqual( self.get_success(self.store.get_user_ip_and_agents(user)), @@ -220,6 +350,82 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ], ) + def test_get_user_ip_and_agents_combined_data(self): + """Test that `get_user_ip_and_agents` combines persisted and unpersisted data + together correctly + """ + self.reactor.advance(12345678) + + user_id = "@user:id" + user = UserID.from_string(user_id) + + # Insert user IPs + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip_1", "user_agent_1", "MY_DEVICE_1" + ) + ) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip_2", "user_agent_2", "MY_DEVICE_2" + ) + ) + + # Trigger the storage loop and wait for the rate limiting period to be over + self.reactor.advance(10 + LAST_SEEN_GRANULARITY / 1000) + + # Update the user agent for the second device, without running the storage loop + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip_2", "user_agent_3", "MY_DEVICE_2" + ) + ) + + # Check that the new IP and user agent has not been stored yet + db_result = self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={}, + retcols=("access_token", "ip", "user_agent", "last_seen"), + ), + ) + self.assertEqual( + db_result, + [ + { + "access_token": "access_token", + "ip": "ip_1", + "user_agent": "user_agent_1", + "last_seen": 12345678000, + }, + { + "access_token": "access_token", + "ip": "ip_2", + "user_agent": "user_agent_2", + "last_seen": 12345678000, + }, + ], + ) + + # Check that data from the database and memory are combined together correctly + self.assertCountEqual( + self.get_success(self.store.get_user_ip_and_agents(user)), + [ + { + "access_token": "access_token", + "ip": "ip_1", + "user_agent": "user_agent_1", + "last_seen": 12345678000, + }, + { + "access_token": "access_token", + "ip": "ip_2", + "user_agent": "user_agent_3", + "last_seen": 12345688000 + LAST_SEEN_GRANULARITY, + }, + ], + ) + @override_config({"limit_usage_by_mau": False, "max_mau_value": 50}) def test_disabled_monthly_active_user(self): user_id = "@user:server" diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py new file mode 100644
index 0000000000..a6be9a1bb1 --- /dev/null +++ b/tests/storage/test_rollback_worker.py
@@ -0,0 +1,69 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.app.generic_worker import GenericWorkerServer +from synapse.storage.database import LoggingDatabaseConnection +from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database +from synapse.storage.schema import SCHEMA_VERSION + +from tests.unittest import HomeserverTestCase + + +class WorkerSchemaTests(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver( + federation_http_client=None, homeserver_to_use=GenericWorkerServer + ) + return hs + + def default_config(self): + conf = super().default_config() + + # Mark this as a worker app. + conf["worker_app"] = "yes" + + return conf + + def test_rolling_back(self): + """Test that workers can start if the DB is a newer schema version""" + + db_pool = self.hs.get_datastore().db_pool + db_conn = LoggingDatabaseConnection( + db_pool._db_pool.connect(), + db_pool.engine, + "tests", + ) + + cur = db_conn.cursor() + cur.execute("UPDATE schema_version SET version = ?", (SCHEMA_VERSION + 1,)) + + db_conn.commit() + + prepare_database(db_conn, db_pool.engine, self.hs.config) + + def test_not_upgraded(self): + """Test that workers don't start if the DB has an older schema version""" + db_pool = self.hs.get_datastore().db_pool + db_conn = LoggingDatabaseConnection( + db_pool._db_pool.connect(), + db_pool.engine, + "tests", + ) + + cur = db_conn.cursor() + cur.execute("UPDATE schema_version SET version = ?", (SCHEMA_VERSION - 1,)) + + db_conn.commit() + + with self.assertRaises(PrepareDatabaseException): + prepare_database(db_conn, db_pool.engine, self.hs.config)