diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
new file mode 100644
index 0000000000..4b67bd15b7
--- /dev/null
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -0,0 +1,164 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest import admin
+from synapse.rest.client import devices
+
+from tests.unittest import HomeserverTestCase
+
+
+class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        devices.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.user_id = self.register_user("foo", "pass")
+
+    def test_background_remove_deleted_devices_from_device_inbox(self):
+        """Test that the background update deletes device_inbox rows for deleted devices."""
+
+        # create a valid device
+        self.get_success(
+            self.store.store_device(self.user_id, "cur_device", "display_name")
+        )
+
+        # Queue a message for the existing device and one for a device that was never created
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "device_inbox",
+                {
+                    "user_id": self.user_id,
+                    "device_id": "cur_device",
+                    "stream_id": 1,
+                    "message_json": "{}",
+                },
+            )
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "device_inbox",
+                {
+                    "user_id": self.user_id,
+                    "device_id": "old_device",  # no matching row in `devices`
+                    "stream_id": 2,
+                    "message_json": "{}",
+                },
+            )
+        )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {
+                    "update_name": "remove_deleted_devices_from_device_inbox",
+                    "progress_json": "{}",
+                },
+            )
+        )
+
+        # ... and tell the DataStore that it hasn't finished all updates yet
+        self.store.db_pool.updates._all_done = False
+
+        self.wait_for_background_updates()
+
+        # Only the message for the still-existing device should remain
+        res = self.get_success(
+            self.store.db_pool.simple_select_onecol(
+                table="device_inbox",
+                keyvalues={},
+                retcol="device_id",
+                desc="get_device_id_from_device_inbox",
+            )
+        )
+        self.assertEqual(1, len(res))
+        self.assertEqual(res[0], "cur_device")
+
+    def test_background_remove_hidden_devices_from_device_inbox(self):
+        """Test that the background update deletes device_inbox rows belonging
+        to hidden devices."""
+
+        # create a valid device
+        self.get_success(
+            self.store.store_device(self.user_id, "cur_device", "display_name")
+        )
+
+        # create a hidden device
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "devices",
+                values={
+                    "user_id": self.user_id,
+                    "device_id": "hidden_device",
+                    "display_name": "hidden_display_name",
+                    "hidden": True,
+                },
+            )
+        )
+
+        # Queue a message for the visible device and one for the hidden device
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "device_inbox",
+                {
+                    "user_id": self.user_id,
+                    "device_id": "cur_device",
+                    "stream_id": 1,
+                    "message_json": "{}",
+                },
+            )
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "device_inbox",
+                {
+                    "user_id": self.user_id,
+                    "device_id": "hidden_device",
+                    "stream_id": 2,
+                    "message_json": "{}",
+                },
+            )
+        )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {
+                    "update_name": "remove_hidden_devices_from_device_inbox",
+                    "progress_json": "{}",
+                },
+            )
+        )
+
+        # ... and tell the DataStore that it hasn't finished all updates yet
+        self.store.db_pool.updates._all_done = False
+
+        self.wait_for_background_updates()
+
+        # Only the message for the non-hidden device should remain
+        res = self.get_success(
+            self.store.db_pool.simple_select_onecol(
+                table="device_inbox",
+                keyvalues={},
+                retcol="device_id",
+                desc="get_device_id_from_device_inbox",
+            )
+        )
+        self.assertEqual(1, len(res))
+        self.assertEqual(res[0], "cur_device")
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 0e4013ebea..c8ac67e35b 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -20,6 +20,7 @@ from parameterized import parameterized
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login
+from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.types import UserID
from tests import unittest
@@ -171,6 +172,27 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
if after_persisting:
# Trigger the storage loop
self.reactor.advance(10)
+ else:
+ # Check that the new IP and user agent has not been stored yet
+ db_result = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={},
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ ),
+ )
+ self.assertEqual(
+ db_result,
+ [
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "ip": None,
+ "user_agent": None,
+ "last_seen": None,
+ },
+ ],
+ )
result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id)
@@ -189,6 +211,104 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
},
)
+    def test_get_last_client_ip_by_device_combined_data(self):
+        """Test that `get_last_client_ip_by_device` combines persisted and unpersisted
+        data together correctly
+        """
+        self.reactor.advance(12345678)
+
+        user_id = "@user:id"
+        device_id_1 = "MY_DEVICE_1"
+        device_id_2 = "MY_DEVICE_2"
+
+        # Register two devices and record a client IP for each of them
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id_1,
+                "display name",
+            )
+        )
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id_2,
+                "display name",
+            )
+        )
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token_1", "ip_1", "user_agent_1", device_id_1
+            )
+        )
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token_2", "ip_2", "user_agent_2", device_id_2
+            )
+        )
+
+        # Trigger the storage loop and wait for the rate limiting period to be over
+        self.reactor.advance(10 + LAST_SEEN_GRANULARITY / 1000)
+
+        # Update the user agent for the second device, without running the storage loop
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token_2", "ip_2", "user_agent_3", device_id_2
+            )
+        )
+
+        # Check that the new IP and user agent has not been stored yet
+        db_result = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="devices",
+                keyvalues={},
+                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+            ),
+        )
+        self.assertCountEqual(
+            db_result,
+            [
+                {
+                    "user_id": user_id,
+                    "device_id": device_id_1,
+                    "ip": "ip_1",
+                    "user_agent": "user_agent_1",
+                    "last_seen": 12345678000,
+                },
+                {
+                    "user_id": user_id,
+                    "device_id": device_id_2,
+                    "ip": "ip_2",
+                    "user_agent": "user_agent_2",
+                    "last_seen": 12345678000,
+                },
+            ],
+        )
+
+        # Check that data from the database and memory are combined together correctly
+        result = self.get_success(
+            self.store.get_last_client_ip_by_device(user_id, None)
+        )
+        self.assertEqual(
+            result,
+            {
+                (user_id, device_id_1): {
+                    "user_id": user_id,
+                    "device_id": device_id_1,
+                    "ip": "ip_1",
+                    "user_agent": "user_agent_1",
+                    "last_seen": 12345678000,
+                },
+                (user_id, device_id_2): {
+                    "user_id": user_id,
+                    "device_id": device_id_2,
+                    "ip": "ip_2",
+                    "user_agent": "user_agent_3",
+                    "last_seen": 12345688000 + LAST_SEEN_GRANULARITY,  # unpersisted, newer value
+                },
+            },
+        )
+
@parameterized.expand([(False,), (True,)])
def test_get_user_ip_and_agents(self, after_persisting: bool):
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
@@ -207,6 +327,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
if after_persisting:
# Trigger the storage loop
self.reactor.advance(10)
+ else:
+ # Check that the new IP and user agent has not been stored yet
+ db_result = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={},
+ retcols=("access_token", "ip", "user_agent", "last_seen"),
+ ),
+ )
+ self.assertEqual(db_result, [])
self.assertEqual(
self.get_success(self.store.get_user_ip_and_agents(user)),
@@ -220,6 +350,82 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
],
)
+    def test_get_user_ip_and_agents_combined_data(self):
+        """Test that `get_user_ip_and_agents` combines persisted and unpersisted data
+        together correctly
+        """
+        self.reactor.advance(12345678)
+
+        user_id = "@user:id"
+        user = UserID.from_string(user_id)
+
+        # Record a client IP for each of two devices
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip_1", "user_agent_1", "MY_DEVICE_1"
+            )
+        )
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip_2", "user_agent_2", "MY_DEVICE_2"
+            )
+        )
+
+        # Trigger the storage loop and wait for the rate limiting period to be over
+        self.reactor.advance(10 + LAST_SEEN_GRANULARITY / 1000)
+
+        # Update the user agent for the second device, without running the storage loop
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip_2", "user_agent_3", "MY_DEVICE_2"
+            )
+        )
+
+        # Check that the new IP and user agent has not been stored yet
+        db_result = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="user_ips",
+                keyvalues={},
+                retcols=("access_token", "ip", "user_agent", "last_seen"),
+            ),
+        )
+        # Row order of a SELECT without ORDER BY is unspecified: compare ignoring order
+        self.assertCountEqual(
+            db_result,
+            [
+                {
+                    "access_token": "access_token",
+                    "ip": "ip_1",
+                    "user_agent": "user_agent_1",
+                    "last_seen": 12345678000,
+                },
+                {
+                    "access_token": "access_token",
+                    "ip": "ip_2",
+                    "user_agent": "user_agent_2",
+                    "last_seen": 12345678000,
+                },
+            ],
+        )
+
+        # Check that data from the database and memory are combined together correctly
+        self.assertCountEqual(
+            self.get_success(self.store.get_user_ip_and_agents(user)),
+            [
+                {
+                    "access_token": "access_token",
+                    "ip": "ip_1",
+                    "user_agent": "user_agent_1",
+                    "last_seen": 12345678000,
+                },
+                {
+                    "access_token": "access_token",
+                    "ip": "ip_2",
+                    "user_agent": "user_agent_3",
+                    "last_seen": 12345688000 + LAST_SEEN_GRANULARITY,  # unpersisted, newer value
+                },
+            ],
+        )
+
@override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
def test_disabled_monthly_active_user(self):
user_id = "@user:server"
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
new file mode 100644
index 0000000000..a6be9a1bb1
--- /dev/null
+++ b/tests/storage/test_rollback_worker.py
@@ -0,0 +1,69 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.app.generic_worker import GenericWorkerServer
+from synapse.storage.database import LoggingDatabaseConnection
+from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
+from synapse.storage.schema import SCHEMA_VERSION
+
+from tests.unittest import HomeserverTestCase
+
+
+class WorkerSchemaTests(HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver(
+            federation_http_client=None, homeserver_to_use=GenericWorkerServer
+        )
+        return hs
+
+    def default_config(self):
+        conf = super().default_config()
+
+        # Mark this as a worker app.
+        conf["worker_app"] = "yes"
+
+        return conf
+
+    def test_rolling_back(self):
+        """A worker must be able to start against a DB with a newer schema version."""
+
+        db_pool = self.hs.get_datastore().db_pool
+        db_conn = LoggingDatabaseConnection(
+            db_pool._db_pool.connect(),  # NOTE(review): raw connection is never closed; acceptable in tests
+            db_pool.engine,
+            "tests",
+        )
+
+        cur = db_conn.cursor()
+        cur.execute("UPDATE schema_version SET version = ?", (SCHEMA_VERSION + 1,))  # pretend the DB is ahead
+
+        db_conn.commit()
+
+        prepare_database(db_conn, db_pool.engine, self.hs.config)  # must not raise
+
+    def test_not_upgraded(self):
+        """A worker must refuse to start against a DB with an older schema version."""
+        db_pool = self.hs.get_datastore().db_pool
+        db_conn = LoggingDatabaseConnection(
+            db_pool._db_pool.connect(),  # NOTE(review): raw connection is never closed; acceptable in tests
+            db_pool.engine,
+            "tests",
+        )
+
+        cur = db_conn.cursor()
+        cur.execute("UPDATE schema_version SET version = ?", (SCHEMA_VERSION - 1,))  # pretend the DB is behind
+
+        db_conn.commit()
+
+        with self.assertRaises(PrepareDatabaseException):
+            prepare_database(db_conn, db_pool.engine, self.hs.config)
|