summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test__base.py8
-rw-r--r--tests/storage/test_appservice.py10
-rw-r--r--tests/storage/test_background_update.py24
-rw-r--r--tests/storage/test_base.py3
-rw-r--r--tests/storage/test_cleanup_extrems.py4
-rw-r--r--tests/storage/test_client_ips.py13
-rw-r--r--tests/storage/test_database.py52
-rw-r--r--tests/storage/test_devices.py45
-rw-r--r--tests/storage/test_event_metrics.py2
-rw-r--r--tests/storage/test_event_push_actions.py3
-rw-r--r--tests/storage/test_id_generators.py184
-rw-r--r--tests/storage/test_main.py48
-rw-r--r--tests/storage/test_monthly_active_users.py110
-rw-r--r--tests/storage/test_room.py11
-rw-r--r--tests/storage/test_roommember.py50
15 files changed, 440 insertions, 127 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py

index e37260a820..5a50e4fdd4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py
@@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached from tests import unittest -class CacheTestCase(unittest.TestCase): - def setUp(self): +class CacheTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): self.cache = Cache("test") def test_empty(self): @@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase): cache.get(3) -class CacheDecoratorTestCase(unittest.TestCase): +class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): class A(object): @@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=4) # HACK: This makes it 2 due to cache factor + @cached(max_entries=2) def func(self, key): callcount[0] += 1 return key diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 31710949a8..ef296e7dab 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py
@@ -43,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_token = "token1" @@ -110,7 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_list = [ @@ -422,7 +422,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] database = hs.get_datastores().databases[0] @@ -440,7 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: @@ -464,7 +464,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index ae14fb407d..940b166129 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py
@@ -11,7 +11,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater # the base test class should have run the real bg updates for us - self.assertTrue(self.updates.has_completed_background_updates()) + self.assertTrue( + self.get_success(self.updates.has_completed_background_updates()) + ) self.update_handler = Mock() self.updates.register_background_update_handler( @@ -25,12 +27,20 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the target runtime for each bg update target_background_update_duration_ms = 50000 + store = self.hs.get_datastore() + self.get_success( + store.db.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield self.hs.get_datastore().db.runInteraction( + yield store.db.runInteraction( "update_progress", self.updates._background_update_progress_txn, "test_update", @@ -39,10 +49,6 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): return count self.update_handler.side_effect = update - - self.get_success( - self.updates.start_background_update("test_update", {"my_key": 1}) - ) self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update( @@ -50,7 +56,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): ), by=0.1, ) - self.assertIsNotNone(res) + self.assertFalse(res) # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( @@ -73,7 +79,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): result = self.get_success( self.updates.do_next_background_update(target_background_update_duration_ms) ) - self.assertIsNotNone(result) + self.assertFalse(result) self.update_handler.assert_called_once() # third step: we don't expect to be called any more @@ -81,5 +87,5 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): result = self.get_success( self.updates.do_next_background_update(target_background_update_duration_ms) ) - self.assertIsNone(result) + self.assertTrue(result) self.assertFalse(self.update_handler.called) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index cdee0a9e60..278961c331 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py
@@ -51,7 +51,8 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True - config.event_cache_size = 1 + config.caches = Mock() + config.caches.event_cache_size = 1 hs = TestHomeServer("test", config=config) sqlite_config = {"name": "sqlite3"} diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 0e04b2cf92..43425c969a 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -39,7 +39,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] def run_background_update(self): @@ -261,7 +261,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() homeserver.config.user_consent_version = self.CONSENT_VERSION diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index bf674dd184..3b483bc7f0 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -23,6 +23,7 @@ from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.unittest import override_config class ClientIpStoreTestCase(unittest.HomeserverTestCase): @@ -137,9 +138,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"limit_usage_by_mau": False, "max_mau_value": 50}) def test_disabled_monthly_active_user(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 user_id = "@user:server" self.get_success( self.store.insert_client_ip( @@ -149,9 +149,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_adding_monthly_active_user_when_full(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 lots_of_users = 100 user_id = "@user:server" @@ -166,9 +165,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_adding_monthly_active_user_when_space(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 user_id = "@user:server" active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) @@ -184,9 +182,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_updating_monthly_active_user_when_space(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 user_id = "@user:server" self.get_success(self.store.register_user(user_id=user_id, password_hash=None)) diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py new file mode 100644
index 0000000000..5a77c84962 --- /dev/null +++ b/tests/storage/test_database.py
@@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.database import make_tuple_comparison_clause +from synapse.storage.engines import BaseDatabaseEngine + +from tests import unittest + + +def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: + # returns a DatabaseEngine, circumventing the abc mechanism + # any kwargs are set as attributes on the class before instantiating it + t = type( + "TestBaseDatabaseEngine", + (BaseDatabaseEngine,), + dict(BaseDatabaseEngine.__dict__), + ) + # defeat the abc mechanism + t.__abstractmethods__ = set() + for k, v in kwargs.items(): + setattr(t, k, v) + return t(None, None) + + +class TupleComparisonClauseTestCase(unittest.TestCase): + def test_native_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=True) + clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)]) + self.assertEqual(clause, "(a,b) > (?,?)") + self.assertEqual(args, [1, 2]) + + def test_emulated_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=False) + clause, args = make_tuple_comparison_clause( + db_engine, [("a", 1), ("b", 2), ("c", 3)] + ) + self.assertEqual( + clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))" + ) + self.assertEqual(args, [1, 1, 2, 2, 3]) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 6f8d990959..c2539b353a 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py
@@ -88,51 +88,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) - @defer.inlineCallbacks - def test_get_device_updates_by_remote_limited(self): - # Test breaking the update limit in 1, 101, and 1 device_id segments - - # first add one device - device_ids1 = ["device_id0"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids1, ["someotherhost"] - ) - - # then add 101 - device_ids2 = ["device_id" + str(i + 1) for i in range(101)] - yield self.store.add_device_change_to_streams( - "user_id", device_ids2, ["someotherhost"] - ) - - # then one more - device_ids3 = ["newdevice"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids3, ["someotherhost"] - ) - - # - # now read them back. - # - - # first we should get a single update - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", -1, limit=100 - ) - self._check_devices_in_updates(device_ids1, device_updates) - - # Then we should get an empty list back as the 101 devices broke the limit - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self.assertEqual(len(device_updates), 0) - - # The 101 devices should've been cleared, so we should now just get one device - # update - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self._check_devices_in_updates(device_ids3, device_updates) - def _check_devices_in_updates(self, expected_device_ids, device_updates): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a7b7fd36d3..a7b85004e5 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py
@@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): events = [(3, 2), (6, 2), (4, 6)] for event_count, extrems in events: - info = self.get_success(room_creator.create_room(requester, {})) + info, _ = self.get_success(room_creator.create_room(requester, {})) room_id = info["room_id"] last_event = None diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index d4bcf1821e..b45bc9c115 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py
@@ -35,6 +35,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.persist_events_store = hs.get_datastores().persist_events @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): @@ -76,7 +77,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) yield self.store.db.runInteraction( "", - self.store._set_push_actions_for_event_and_users_txn, + self.persist_events_store._set_push_actions_for_event_and_users_txn, [(event, None)], [(event, None)], ) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py new file mode 100644
index 0000000000..55e9ecf264 --- /dev/null +++ b/tests/storage/test_id_generators.py
@@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.storage.database import Database +from synapse.storage.util.id_generators import MultiWriterIdGenerator + +from tests.unittest import HomeserverTestCase +from tests.utils import USE_POSTGRES_FOR_TESTS + + +class MultiWriterIdGeneratorTestCase(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db = self.store.db # type: Database + + self.get_success(self.db.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db, + instance_name=instance_name, + table="foobar", + instance_column="instance_name", + id_column="stream_id", + sequence_name="foobar_seq", + ) + + return self.get_success(self.db.runWithConnection(_create)) + + def _insert_rows(self, instance_name: str, number: int): + def _insert(txn): + for _ in range(number): + txn.execute( + "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", + (instance_name,), + ) + + self.get_success(self.db.runInteraction("test_single_instance", _insert)) + + def test_empty(self): + """Test an ID generator against an empty database gives sensible + current positions. + """ + + id_gen = self._create_id_generator() + + # The table is empty so we expect an empty map for positions + self.assertEqual(id_gen.get_positions(), {}) + + def test_single_instance(self): + """Test that reads and writes from a single process are handled + correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + async def _get_next_async(): + with await id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + self.get_success(_get_next_async()) + + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token("master"), 8) + + def test_multi_instance(self): + """Test that reads and writes from multiple processes are handled + correctly. + """ + self._insert_rows("first", 3) + self._insert_rows("second", 4) + + first_id_gen = self._create_id_generator("first") + second_id_gen = self._create_id_generator("second") + + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_current_token("first"), 3) + self.assertEqual(first_id_gen.get_current_token("second"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + async def _get_next_async(): + with await first_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + + self.assertEqual( + first_id_gen.get_positions(), {"first": 3, "second": 7} + ) + + self.get_success(_get_next_async()) + + self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7}) + + # However the ID gen on the second instance won't have seen the update + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + + # ... but calling `get_next` on the second instance should give a unique + # stream ID + + async def _get_next_async(): + with await second_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 9) + + self.assertEqual( + second_id_gen.get_positions(), {"first": 3, "second": 7} + ) + + self.get_success(_get_next_async()) + + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) + + # If the second ID gen gets told about the first, it correctly updates + second_id_gen.advance("first", 8) + self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) + + def test_get_next_txn(self): + """Test that the `get_next_txn` function works correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + def _get_next_txn(txn): + stream_id = id_gen.get_next_txn(txn) + self.assertEqual(stream_id, 8) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + self.get_success(self.db.runInteraction("test", _get_next_txn)) + + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token("master"), 8) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py new file mode 100644
index 0000000000..0155ffd04e --- /dev/null +++ b/tests/storage/test_main.py
@@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Awesome Technologies Innovationslabor GmbH +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.types import UserID + +from tests import unittest +from tests.utils import setup_test_homeserver + + +class DataStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + hs = yield setup_test_homeserver(self.addCleanup) + + self.store = hs.get_datastore() + + self.user = UserID.from_string("@abcde:test") + self.displayname = "Frank" + + @defer.inlineCallbacks + def test_get_users_paginate(self): + yield self.store.register_user(self.user.to_string(), "pass") + yield self.store.create_profile(self.user.localpart) + yield self.store.set_profile_displayname( + self.user.localpart, self.displayname, 1 + ) + + users, total = yield self.store.get_users_paginate( + 0, 10, name="bc", guests=False + ) + + self.assertEquals(1, total) + self.assertEquals(self.displayname, users.pop()["displayname"]) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index bc53bf0951..447fcb3a1c 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py
@@ -19,94 +19,106 @@ from twisted.internet import defer from synapse.api.constants import UserTypes from tests import unittest +from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 +def gen_3pids(count): + """Generate `count` threepids as a list.""" + return [ + {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count) + ] + + class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def default_config(self): + config = default_config("test") + + config.update({"limit_usage_by_mau": True, "max_mau_value": 50}) - hs = self.setup_test_homeserver() - self.store = hs.get_datastore() - hs.config.limit_usage_by_mau = True - hs.config.max_mau_value = 50 + # apply any additional config which was specified via the override_config + # decorator. + if self._extra_config is not None: + config.update(self._extra_config) + return config + + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() # Advance the clock a bit reactor.advance(FORTY_DAYS) - return hs - + @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) def test_initialise_reserved_users(self): - self.hs.config.max_mau_value = 5 + threepids = self.hs.config.mau_limits_reserved_threepids + + # register three users, of which two have reserved 3pids, and a third + # which is a support user. user1 = "@user1:server" - user1_email = "user1@matrix.org" + user1_email = threepids[0]["address"] user2 = "@user2:server" - user2_email = "user2@matrix.org" + user2_email = threepids[1]["address"] user3 = "@user3:server" - user3_email = "user3@matrix.org" - threepids = [ - {"medium": "email", "address": user1_email}, - {"medium": "email", "address": user2_email}, - {"medium": "email", "address": user3_email}, - ] - self.hs.config.mau_limits_reserved_threepids = threepids - # -1 because user3 is a support user and does not count - user_num = len(threepids) - 1 - - self.store.register_user(user_id=user1, password_hash=None) - self.store.register_user(user_id=user2, password_hash=None) - self.store.register_user( - user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT - ) + self.store.register_user(user_id=user1) + self.store.register_user(user_id=user2) + self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT) self.pump() now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) + # XXX why are we doing this here? this function is only run at startup + # so it is odd to re-run it here. self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.pump() - active_count = self.store.get_monthly_active_count() + # the number of users we expect will be counted against the mau limit + # -1 because user3 is a support user and does not count + user_num = len(threepids) - 1 - # Test total counts, ensure user3 (support user) is not counted - self.assertEquals(self.get_success(active_count), user_num) + # Check the number of active users. Ensure user3 (support user) is not counted + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEquals(active_count, user_num) - # Test user is marked as active + # Test each of the registered users is marked as active timestamp = self.store.user_last_seen_monthly_active(user1) self.assertTrue(self.get_success(timestamp)) timestamp = self.store.user_last_seen_monthly_active(user2) self.assertTrue(self.get_success(timestamp)) - # Test that users are never removed from the db. + # Test that users with reserved 3pids are not removed from the MAU table + # XXX some of this is redundant. poking things into the config shouldn't + # work, and in any case it's not obvious what we expect to happen when + # we advance the reactor. self.hs.config.max_mau_value = 0 - self.reactor.advance(FORTY_DAYS) self.hs.config.max_mau_value = 5 - self.store.reap_monthly_active_users() self.pump() active_count = self.store.get_monthly_active_count() self.assertEquals(self.get_success(active_count), user_num) - # Test that regular users are removed from the db + # Add some more users and check they are counted as active ru_count = 2 self.store.upsert_monthly_active_user("@ru1:server") self.store.upsert_monthly_active_user("@ru2:server") self.pump() - active_count = self.store.get_monthly_active_count() self.assertEqual(self.get_success(active_count), user_num + ru_count) - self.hs.config.max_mau_value = user_num + + # now run the reaper and check that the number of active users is reduced + # to max_mau_value self.store.reap_monthly_active_users() self.pump() active_count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(active_count), user_num) + self.assertEquals(self.get_success(active_count), 3) def test_can_insert_and_count_mau(self): count = self.store.get_monthly_active_count() @@ -136,8 +148,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): result = self.store.user_last_seen_monthly_active(user_id3) self.assertNotEqual(self.get_success(result), 0) + @override_config({"max_mau_value": 5}) def test_reap_monthly_active_users(self): - self.hs.config.max_mau_value = 5 initial_users = 10 for i in range(initial_users): self.store.upsert_monthly_active_user("@user%d:server" % i) @@ -158,19 +170,19 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.store.get_monthly_active_count() self.assertEquals(self.get_success(count), 0) + # Note that below says mau_limit (no s), this is the name of the config + # value, although it gets stored on the config object as mau_limits. + @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)}) def test_reap_monthly_active_users_reserved_users(self): """ Tests that reaping correctly handles reaping where reserved users are present""" - - self.hs.config.max_mau_value = 5 - initial_users = 5 + threepids = self.hs.config.mau_limits_reserved_threepids + initial_users = len(threepids) reserved_user_number = initial_users - 1 - threepids = [] for i in range(initial_users): user = "@user%d:server" % i - email = "user%d@example.com" % i + email = "user%d@matrix.org" % i self.get_success(self.store.upsert_monthly_active_user(user)) - threepids.append({"medium": "email", "address": email}) # Need to ensure that the most recent entries in the # monthly_active_users table are reserved now = int(self.hs.get_clock().time_msec()) @@ -182,7 +194,6 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid(user, "email", email, now, now) ) - self.hs.config.mau_limits_reserved_threepids = threepids self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) @@ -279,11 +290,11 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.pump() self.assertEqual(self.get_success(count), 0) + # Note that the max_mau_value setting should not matter. + @override_config( + {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1} + ) def test_track_monthly_users_without_cap(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.mau_stats_only = True - self.hs.config.max_mau_value = 1 # should not matter - count = self.store.get_monthly_active_count() self.assertEqual(0, self.get_success(count)) @@ -294,9 +305,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.store.get_monthly_active_count() self.assertEqual(2, self.get_success(count)) + @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.mau_stats_only = False self.store.upsert_monthly_active_user = Mock() self.store.populate_monthly_active_users("@user:sever") diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 086adeb8fd..3b78d48896 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py
@@ -55,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room(self.room.to_string())), ) + @defer.inlineCallbacks + def test_get_room_with_stats(self): + self.assertDictContainsSubset( + { + "room_id": self.room.to_string(), + "creator": self.u_creator.to_string(), + "public": True, + }, + (yield self.store.get_room_with_stats(self.room.to_string())), + ) + class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 00df0ea68e..5dd46005e6 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py
@@ -22,6 +22,8 @@ from synapse.rest.client.v1 import login, room from synapse.types import Requester, UserID from tests import unittest +from tests.test_utils import event_injection +from tests.utils import TestHomeServer class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @@ -38,7 +40,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor, clock, hs: TestHomeServer): # We can't test the RoomMemberStore on its own without the other event # storage logic @@ -114,6 +116,52 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) + def test_get_joined_users_from_context(self): + room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + bob_event = event_injection.inject_member_event( + self.hs, room, self.u_bob, Membership.JOIN + ) + + # first, create a regular event + event, context = event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[bob_event.event_id], + type="m.test.1", + content={}, + ) + + users = self.get_success( + self.store.get_joined_users_from_context(event, context) + ) + self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + + # Regression test for #7376: create a state event whose key matches bob's + # user_id, but which is *not* a membership event, and persist that; then check + # that `get_joined_users_from_context` returns the correct users for the next event. + non_member_event = event_injection.inject_event( + self.hs, + room_id=room, + sender=self.u_bob, + prev_event_ids=[bob_event.event_id], + type="m.test.2", + state_key=self.u_bob, + content={}, + ) + event, context = event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[non_member_event.event_id], + type="m.test.3", + content={}, + ) + users = self.get_success( + self.store.get_joined_users_from_context(event, context) + ) + self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver):