summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py14
-rw-r--r--tests/appservice/test_scheduler.py40
-rw-r--r--tests/config/test_base.py21
-rw-r--r--tests/config/test_cache.py70
-rw-r--r--tests/config/test_load.py18
-rw-r--r--tests/config/test_tls.py38
-rw-r--r--tests/events/test_presence_router.py7
-rw-r--r--tests/federation/test_federation_sender.py6
-rw-r--r--tests/federation/test_federation_server.py2
-rw-r--r--tests/handlers/test_profile.py4
-rw-r--r--tests/handlers/test_register.py95
-rw-r--r--tests/handlers/test_stats.py21
-rw-r--r--tests/handlers/test_user_directory.py623
-rw-r--r--tests/http/test_fedclient.py2
-rw-r--r--tests/logging/test_terse_json.py28
-rw-r--r--tests/module_api/test_api.py7
-rw-r--r--tests/replication/_base.py23
-rw-r--r--tests/rest/admin/test_user.py10
-rw-r--r--tests/rest/client/test_account.py55
-rw-r--r--tests/rest/client/test_capabilities.py2
-rw-r--r--tests/rest/client/test_identity.py2
-rw-r--r--tests/rest/client/test_login.py23
-rw-r--r--tests/rest/client/test_presence.py2
-rw-r--r--tests/rest/client/test_register.py8
-rw-r--r--tests/rest/client/test_rooms.py171
-rw-r--r--tests/rest/client/utils.py13
-rw-r--r--tests/rest/media/v1/test_url_preview.py131
-rw-r--r--tests/server.py8
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py2
-rw-r--r--tests/storage/databases/main/test_room.py7
-rw-r--r--tests/storage/test_appservice.py2
-rw-r--r--tests/storage/test_cleanup_extrems.py7
-rw-r--r--tests/storage/test_client_ips.py21
-rw-r--r--tests/storage/test_event_chain.py14
-rw-r--r--tests/storage/test_monthly_active_users.py14
-rw-r--r--tests/storage/test_roommember.py14
-rw-r--r--tests/storage/test_txn_limit.py2
-rw-r--r--tests/storage/test_user_directory.py391
-rw-r--r--tests/test_event_auth.py108
-rw-r--r--tests/test_federation.py1
-rw-r--r--tests/test_mau.py39
-rw-r--r--tests/test_preview.py40
-rw-r--r--tests/unittest.py50
43 files changed, 1563 insertions, 593 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py

index cccff7af26..3aa9ba3c43 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py
@@ -217,7 +217,7 @@ class AuthTestCase(unittest.HomeserverTestCase): user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, + location=self.hs.config.server.server_name, identifier="key", key=self.hs.config.key.macaroon_secret_key, ) @@ -239,7 +239,7 @@ class AuthTestCase(unittest.HomeserverTestCase): user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, + location=self.hs.config.server.server_name, identifier="key", key=self.hs.config.key.macaroon_secret_key, ) @@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_monthly_active_count = simple_async_mock(lots_of_users) e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) @@ -303,7 +303,7 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -332,7 +332,7 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -372,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) @@ -387,7 +387,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index a2b5ed2030..55f0899bae 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py
@@ -24,7 +24,7 @@ from synapse.appservice.scheduler import ( from synapse.logging.context import make_deferred_yieldable from tests import unittest -from tests.test_utils import make_awaitable +from tests.test_utils import simple_async_mock from ..utils import MockClock @@ -49,11 +49,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = Mock( - return_value=defer.succeed(ApplicationServiceState.UP) - ) - txn.send = Mock(return_value=make_awaitable(True)) - self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) + self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) + txn.send = simple_async_mock(True) + txn.complete = simple_async_mock(True) + self.store.create_appservice_txn = simple_async_mock(txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -71,10 +70,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events = [Mock(), Mock()] txn = Mock(id="idhere", service=service, events=events) - self.store.get_appservice_state = Mock( - return_value=defer.succeed(ApplicationServiceState.DOWN) + self.store.get_appservice_state = simple_async_mock( + ApplicationServiceState.DOWN ) - self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) + self.store.create_appservice_txn = simple_async_mock(txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -94,12 +93,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = Mock( - return_value=defer.succeed(ApplicationServiceState.UP) - ) - self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) - txn.send = Mock(return_value=make_awaitable(False)) # fails to send - self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) + self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) + self.store.set_appservice_state = simple_async_mock(True) + txn.send = simple_async_mock(False) # fails to send + self.store.create_appservice_txn = simple_async_mock(txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -122,7 +119,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.as_api = Mock() self.store = Mock() self.service = Mock() - self.callback = Mock() + self.callback = simple_async_mock() self.recoverer = _Recoverer( clock=self.clock, as_api=self.as_api, @@ -144,8 +141,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=make_awaitable(True)) - txn.complete.return_value = make_awaitable(None) + txn.send = simple_async_mock(True) + txn.complete = simple_async_mock(None) # wait for exp backoff self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) @@ -170,8 +167,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=make_awaitable(False)) - txn.complete.return_value = make_awaitable(None) + txn.send = simple_async_mock(False) + txn.complete = simple_async_mock(None) self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) @@ -184,7 +181,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.assertEquals(3, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) self.assertEquals(0, self.callback.call_count) - txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn + txn.send = simple_async_mock(True) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) self.assertEquals(1, txn.send.call_count) # new mock reset call count @@ -195,6 +192,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): def setUp(self): self.txn_ctrl = Mock() + self.txn_ctrl.send = simple_async_mock() self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock()) def test_send_single_event_no_queue(self): diff --git a/tests/config/test_base.py b/tests/config/test_base.py
index baa5313fb3..6a52f862f4 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py
@@ -14,23 +14,28 @@ import os.path import tempfile +from unittest.mock import Mock from synapse.config import ConfigError +from synapse.config._base import Config from synapse.util.stringutils import random_string from tests import unittest -class BaseConfigTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): - self.hs = hs +class BaseConfigTestCase(unittest.TestCase): + def setUp(self): + # The root object needs a server property with a public_baseurl. + root = Mock() + root.server.public_baseurl = "http://test" + self.config = Config(root) def test_loading_missing_templates(self): # Use a temporary directory that exists on the system, but that isn't likely to # contain template files with tempfile.TemporaryDirectory() as tmp_dir: # Attempt to load an HTML template from our custom template directory - template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0] + template = self.config.read_templates(["sso_error.html"], (tmp_dir,))[0] # If no errors, we should've gotten the default template instead @@ -60,7 +65,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Attempt to load the template from our custom template directory template = ( - self.hs.config.read_templates([template_filename], (tmp_dir,)) + self.config.read_templates([template_filename], (tmp_dir,)) )[0] # Render the template @@ -97,7 +102,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Retrieve the template. template = ( - self.hs.config.read_templates( + self.config.read_templates( [template_filename], (td.name for td in tempdirs), ) @@ -118,7 +123,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Retrieve the template. template = ( - self.hs.config.read_templates( + self.config.read_templates( [other_template_name], (td.name for td in tempdirs), ) @@ -134,6 +139,6 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): def test_loading_template_from_nonexistent_custom_directory(self): with self.assertRaises(ConfigError): - self.hs.config.read_templates( + self.config.read_templates( ["some_filename.html"], ("a_nonexistent_directory",) ) diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 857d9cd096..79d417568d 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py
@@ -12,59 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.config._base import Config, RootConfig +from unittest.mock import patch + from synapse.config.cache import CacheConfig, add_resizable_cache from synapse.util.caches.lrucache import LruCache from tests.unittest import TestCase -class FakeServer(Config): - section = "server" - - -class TestConfig(RootConfig): - config_classes = [FakeServer, CacheConfig] - - +# Patch the global _CACHES so that each test runs against its own state. +@patch("synapse.config.cache._CACHES", new_callable=dict) class CacheConfigTests(TestCase): def setUp(self): # Reset caches before each test - TestConfig().caches.reset() + self.config = CacheConfig() + + def tearDown(self): + self.config.reset() - def test_individual_caches_from_environ(self): + def test_individual_caches_from_environ(self, _caches): """ Individual cache factors will be loaded from the environment. """ config = {} - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_NOT_CACHE": "BLAH", } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) - def test_config_overrides_environ(self): + def test_config_overrides_environ(self, _caches): """ Individual cache factors defined in the environment will take precedence over those in the config. """ config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_CACHE_FACTOR_FOO": 1, } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual( - dict(t.caches.cache_factors), + dict(self.config.cache_factors), {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, ) - def test_individual_instantiated_before_config_load(self): + def test_individual_instantiated_before_config_load(self, _caches): """ If a cache is instantiated before the config is read, it will be given the default cache size in the interim, and then resized once the config @@ -76,26 +72,24 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 50) config = {"caches": {"per_cache_factors": {"foo": 3}}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config) self.assertEqual(cache.max_size, 300) - def test_individual_instantiated_after_config_load(self): + def test_individual_instantiated_after_config_load(self, _caches): """ If a cache is instantiated after the config is read, it will be immediately resized to the correct size given the per_cache_factor if there is one. """ config = {"caches": {"per_cache_factors": {"foo": 2}}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 200) - def test_global_instantiated_before_config_load(self): + def test_global_instantiated_before_config_load(self, _caches): """ If a cache is instantiated before the config is read, it will be given the default cache size in the interim, and then resized to the new @@ -106,26 +100,24 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 50) config = {"caches": {"global_factor": 4}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(cache.max_size, 400) - def test_global_instantiated_after_config_load(self): + def test_global_instantiated_after_config_load(self, _caches): """ If a cache is instantiated after the config is read, it will be immediately resized to the correct size given the global factor if there is no per-cache factor. """ config = {"caches": {"global_factor": 1.5}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 150) - def test_cache_with_asterisk_in_name(self): + def test_cache_with_asterisk_in_name(self, _caches): """Some caches have asterisks in their name, test that they are set correctly.""" config = { @@ -133,12 +125,11 @@ class CacheConfigTests(TestCase): "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2} } } - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_CACHE_A": "2", "SYNAPSE_CACHE_FACTOR_CACHE_B": 3, } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache_a = LruCache(100) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) @@ -152,17 +143,16 @@ class CacheConfigTests(TestCase): add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor) self.assertEqual(cache_c.max_size, 200) - def test_apply_cache_factor_from_config(self): + def test_apply_cache_factor_from_config(self, _caches): """Caches can disable applying cache factor updates, mainly used by event cache size. """ config = {"caches": {"event_cache_size": "10k"}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache( - max_size=t.caches.event_cache_size, + max_size=self.config.event_cache_size, apply_cache_factor_from_config=False, ) add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor) diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index ef6c2beec7..59635de205 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py
@@ -49,7 +49,7 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_config("", ["-c", self.file]) self.assertTrue( - hasattr(config, "macaroon_secret_key"), + hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) if len(config.key.macaroon_secret_key) < 5: @@ -60,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) self.assertTrue( - hasattr(config, "macaroon_secret_key"), + hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) if len(config.key.macaroon_secret_key) < 5: @@ -74,8 +74,12 @@ class ConfigLoadingTestCase(unittest.TestCase): config1 = HomeServerConfig.load_config("", ["-c", self.file]) config2 = HomeServerConfig.load_config("", ["-c", self.file]) config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key) - self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key) + self.assertEqual( + config1.key.macaroon_secret_key, config2.key.macaroon_secret_key + ) + self.assertEqual( + config1.key.macaroon_secret_key, config3.key.macaroon_secret_key + ) def test_disable_registration(self): self.generate_config() @@ -84,16 +88,16 @@ class ConfigLoadingTestCase(unittest.TestCase): ) # Check that disable_registration clobbers enable_registration. config = HomeServerConfig.load_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) # Check that either config value is clobbered by the command line. config = HomeServerConfig.load_or_generate_config( "", ["-c", self.file, "--enable-registration"] ) - self.assertTrue(config.enable_registration) + self.assertTrue(config.registration.enable_registration) def test_stats_enabled(self): self.generate_config_and_remove_lines_containing("enable_metrics") diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index b6bc1876b5..9ba5781573 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py
@@ -42,9 +42,9 @@ class TLSConfigTests(TestCase): """ config = {} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") def test_tls_client_minimum_set(self): """ @@ -52,29 +52,29 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") config = {"federation_client_minimum_tls_version": 1.1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.1") config = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") # Also test a string version config = {"federation_client_minimum_tls_version": "1"} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") config = {"federation_client_minimum_tls_version": "1.2"} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") def test_tls_client_minimum_1_point_3_missing(self): """ @@ -91,7 +91,7 @@ class TLSConfigTests(TestCase): config = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() with self.assertRaises(ConfigError) as e: - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual( e.exception.args[0], ( @@ -112,8 +112,8 @@ class TLSConfigTests(TestCase): config = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.3") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3") def test_tls_client_minimum_set_passed_through_1_2(self): """ @@ -121,7 +121,7 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) options = _get_ssl_context_options(cf._verify_ssl_context) @@ -137,7 +137,7 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) options = _get_ssl_context_options(cf._verify_ssl_context) @@ -159,7 +159,7 @@ class TLSConfigTests(TestCase): } t = TestConfig() e = self.assertRaises( - ConfigError, t.read_config, config, config_dir_path="", data_dir_path="" + ConfigError, t.tls.read_config, config, config_dir_path="", data_dir_path="" ) self.assertIn("IDNA domain names", str(e)) @@ -174,7 +174,7 @@ class TLSConfigTests(TestCase): ] } t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 3b3866bff8..3deb14c308 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py
@@ -26,6 +26,7 @@ from synapse.rest.client import login, presence, room from synapse.types import JsonDict, StreamToken, create_requester from tests.handlers.test_sync import generate_sync_config +from tests.test_utils import simple_async_mock from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config @@ -133,8 +134,12 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ] def make_homeserver(self, reactor, clock): + # Mock out the calls over federation. + fed_transport_client = Mock(spec=["send_transaction"]) + fed_transport_client.send_transaction = simple_async_mock({}) + hs = self.setup_test_homeserver( - federation_transport_client=Mock(spec=["send_transaction"]), + federation_transport_client=fed_transport_client, ) # Load the modules into the homeserver module_api = hs.get_module_api() diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 65b18fbd7a..b457dad6d2 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py
@@ -336,7 +336,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): recovery """ mock_send_txn = self.hs.get_federation_transport_client().send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) # create devices u1 = self.register_user("user", "pass") @@ -376,7 +376,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): This case tests the behaviour when the server has never been reachable. """ mock_send_txn = self.hs.get_federation_transport_client().send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) # create devices u1 = self.register_user("user", "pass") @@ -429,7 +429,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # now the server goes offline mock_send_txn = self.hs.get_federation_transport_client().send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D3") diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 0b60cc4261..03e1e11f49 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py
@@ -120,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): self.assertEqual( channel.json_body["room_version"], - self.hs.config.default_room_version.identifier, + self.hs.config.server.default_room_version.identifier, ) members = set( diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 57cc3e2646..c153018fd8 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py
@@ -110,7 +110,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_name_if_disabled(self): - self.hs.config.enable_set_displayname = False + self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed self.get_success( @@ -225,7 +225,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_avatar_if_disabled(self): - self.hs.config.enable_set_avatar_url = False + self.hs.config.registration.enable_set_avatar_url = False # Setting displayname for the first time is allowed self.get_success( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index d3efb67e3e..db691c4c1c 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -16,7 +16,12 @@ from unittest.mock import Mock from synapse.api.auth import Auth from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, ResourceLimitError, SynapseError +from synapse.api.errors import ( + CodeMessageException, + Codes, + ResourceLimitError, + SynapseError, +) from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, RoomID, UserID, create_requester @@ -120,14 +125,24 @@ class RegistrationTestCase(unittest.HomeserverTestCase): hs_config = self.default_config() # some of the tests rely on us having a user consent version - hs_config["user_consent"] = { - "version": "test_consent_version", - "template_dir": ".", - } + hs_config.setdefault("user_consent", {}).update( + { + "version": "test_consent_version", + "template_dir": ".", + } + ) hs_config["max_mau_value"] = 50 hs_config["limit_usage_by_mau"] = True - hs = self.setup_test_homeserver(config=hs_config) + # Don't attempt to reach out over federation. + self.mock_federation_client = Mock() + self.mock_federation_client.make_query.side_effect = CodeMessageException( + 500, "" + ) + + hs = self.setup_test_homeserver( + config=hs_config, federation_client=self.mock_federation_client + ) load_legacy_spam_checkers(hs) @@ -138,9 +153,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): - self.mock_distributor = Mock() - self.mock_distributor.declare("registered_user") - self.mock_captcha_client = Mock() self.handler = self.hs.get_registration_handler() self.store = self.hs.get_datastore() self.lots_of_users = 100 @@ -174,21 +186,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) + @override_config({"limit_usage_by_mau": False}) def test_mau_limits_when_disabled(self): - self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "a", "display_name")) + @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self): - self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value - 1) + return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) + @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self): - self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -198,15 +210,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) + @override_config({"limit_usage_by_mau": True}) def test_register_mau_blocked(self): - self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -215,16 +227,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError ) + @override_config( + {"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False} + ) def test_auto_join_rooms_for_guests(self): - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - self.hs.config.auto_join_rooms_for_guests = False user_id = self.get_success( self.handler.register_user(localpart="jeff", make_guest=True), ) @@ -243,34 +255,33 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(room_id["room_id"] in rooms) self.assertEqual(len(rooms), 1) + @override_config({"auto_join_rooms": []}) def test_auto_create_auto_join_rooms_with_no_rooms(self): - self.hs.config.auto_join_rooms = [] frank = UserID.from_string("@frank:test") user_id = self.get_success(self.handler.register_user(frank.localpart)) self.assertEqual(user_id, frank.to_string()) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config({"auto_join_rooms": ["#room:another"]}) def test_auto_create_auto_join_where_room_is_another_domain(self): - self.hs.config.auto_join_rooms = ["#room:another"] frank = UserID.from_string("@frank:test") user_id = self.get_success(self.handler.register_user(frank.localpart)) self.assertEqual(user_id, frank.to_string()) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config( + {"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False} + ) def test_auto_create_auto_join_where_auto_create_is_false(self): - self.hs.config.autocreate_auto_join_rooms = False - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] user_id = self.get_success(self.handler.register_user(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=make_awaitable(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) @@ -294,10 +305,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(room_id["room_id"] in rooms) self.assertEqual(len(rooms), 1) + @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self): - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=make_awaitable(2)) self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) @@ -510,6 +519,17 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEqual(rooms, set()) self.assertEqual(invited_rooms, []) + @override_config( + { + "user_consent": { + "block_events_error": "Error", + "require_at_registration": True, + }, + "form_secret": "53cr3t", + "public_baseurl": "http://test", + "auto_join_rooms": ["#room:test"], + }, + ) def test_auto_create_auto_join_where_no_consent(self): """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. @@ -521,25 +541,20 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # * The server is configured to auto-join to a room # (and autocreate if necessary) - event_creation_handler = self.hs.get_event_creation_handler() - # (Messing with the internals of event_creation_handler is fragile - # but can't see a better way to do this. One option could be to subclass - # the test with custom config.) - event_creation_handler._block_events_without_consent_error = "Error" - event_creation_handler._consent_uri_builder = Mock() - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - # When:- - # * the user is registered and post consent actions are called + # * the user is registered user_id = self.get_success(self.handler.register_user(localpart="jeff")) - self.get_success(self.handler.post_consent_actions(user_id)) # Then:- # * Ensure that they have not been joined to the room rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + # The user provides consent; ensure they are now in the rooms. + self.get_success(self.handler.post_consent_actions(user_id)) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 1) + def test_register_support_user(self): user_id = self.get_success( self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 24b7ef6efc..56207f4db6 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py
@@ -103,12 +103,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the stats via the background update self._add_background_updates() - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() def test_initial_room(self): """ @@ -140,12 +135,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() r = self.get_success(self.get_all_room_state()) @@ -568,12 +558,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() r1stats_complete = self._get_current_stats("room", r1) u1stats_complete = self._get_current_stats("user", u1) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 266333c553..db65253773 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -11,47 +11,208 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple -from unittest.mock import Mock +from typing import Tuple +from unittest.mock import Mock, patch from urllib.parse import quote from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.rest.client import login, room, user_directory +from synapse.appservice import ApplicationService +from synapse.rest.client import login, register, room, user_directory +from synapse.server import HomeServer from synapse.storage.roommember import ProfileInfo from synapse.types import create_requester +from synapse.util import Clock from tests import unittest +from tests.storage.test_user_directory import GetUserDirectoryTables +from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config class UserDirectoryTestCase(unittest.HomeserverTestCase): - """ - Tests the UserDirectoryHandler. + """Tests the UserDirectoryHandler. + + We're broadly testing two kinds of things here. + + 1. Check that we correctly update the user directory in response + to events (e.g. join a room, leave a room, change name, make public) + 2. Check that the search logic behaves as expected. + + The background process that rebuilds the user directory is tested in + tests/storage/test_user_directory.py. """ servlets = [ login.register_servlets, synapse.rest.admin.register_servlets, + register.register_servlets, room.register_servlets, ] - def make_homeserver(self, reactor, clock): - + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["update_user_directory"] = True - return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + self.appservice = ApplicationService( + token="i_am_an_app_service", + hostname="test", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.handler = hs.get_user_directory_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.event_creation_handler = self.hs.get_event_creation_handler() + self.user_dir_helper = GetUserDirectoryTables(self.store) + + def test_normal_user_pair(self) -> None: + """Sanity check that the room-sharing tables are updated correctly.""" + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + + public = self.helper.create_room_as( + alice, + is_public=True, + extra_content={"visibility": "public"}, + tok=alice_token, + ) + private = self.helper.create_room_as(alice, is_public=False, tok=alice_token) + self.helper.invite(private, alice, bob, tok=alice_token) + self.helper.join(public, bob, tok=bob_token) + self.helper.join(private, bob, tok=bob_token) + + # Alice also makes a second public room but no-one else joins + public2 = self.helper.create_room_as( + alice, + is_public=True, + extra_content={"visibility": "public"}, + tok=alice_token, + ) + + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + in_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + + self.assertEqual(users, {alice, bob}) + self.assertEqual( + set(in_public), {(alice, public), (bob, public), (alice, public2)} + ) + self.assertEqual( + self.user_dir_helper._compress_shared(in_private), + {(alice, bob, private), (bob, alice, private)}, + ) + + # The next three tests (test_population_excludes_*) all setup + # - A normal user included in the user dir + # - A public and private room created by that user + # - A user excluded from the room dir, belonging to both rooms + + # They match similar logic in storage/test_user_directory. But that tests + # rebuilding the directory; this tests updating it incrementally. + + def test_excludes_support_user(self) -> None: + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + support = "@support1:test" + self.get_success( + self.store.register_user( + user_id=support, password_hash=None, user_type=UserTypes.SUPPORT + ) + ) + + public, private = self._create_rooms_and_inject_memberships( + alice, alice_token, support + ) + self._check_only_one_user_in_directory(alice, public) + + def test_excludes_deactivated_user(self) -> None: + admin = self.register_user("admin", "pass", admin=True) + admin_token = self.login(admin, "pass") + user = self.register_user("naughty", "pass") + + # Deactivate the user. + channel = self.make_request( + "PUT", + f"/_synapse/admin/v2/users/{user}", + access_token=admin_token, + content={"deactivated": True}, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["deactivated"], True) + + # Join the deactivated user to rooms owned by the admin. + # Is this something that could actually happen outside of a test? + public, private = self._create_rooms_and_inject_memberships( + admin, admin_token, user + ) + self._check_only_one_user_in_directory(admin, public) + + def test_excludes_appservices_user(self) -> None: + # Register an AS user. + user = self.register_user("user", "pass") + token = self.login(user, "pass") + as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + + # Join the AS user to rooms owned by the normal user. + public, private = self._create_rooms_and_inject_memberships( + user, token, as_user + ) + self._check_only_one_user_in_directory(user, public) + + def _create_rooms_and_inject_memberships( + self, creator: str, token: str, joiner: str + ) -> Tuple[str, str]: + """Create a public and private room as a normal user. + Then get the `joiner` into those rooms. + """ + # TODO: Duplicates the same-named method in UserDirectoryInitialPopulationTest. + public_room = self.helper.create_room_as( + creator, + is_public=True, + # See https://github.com/matrix-org/synapse/issues/10951 + extra_content={"visibility": "public"}, + tok=token, + ) + private_room = self.helper.create_room_as(creator, is_public=False, tok=token) + + # HACK: get the user into these rooms + self.get_success(inject_member_event(self.hs, public_room, joiner, "join")) + self.get_success(inject_member_event(self.hs, private_room, joiner, "join")) + + return public_room, private_room + + def _check_only_one_user_in_directory(self, user: str, public: str) -> None: + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + in_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) - def test_handle_local_profile_change_with_support_user(self): + self.assertEqual(users, {user}) + self.assertEqual(set(in_public), {(user, public)}) + self.assertEqual(in_private, []) + + def test_handle_local_profile_change_with_support_user(self) -> None: support_user_id = "@support:test" self.get_success( self.store.register_user( @@ -64,7 +225,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.handler.handle_local_profile_change(support_user_id, None) + self.handler.handle_local_profile_change( + support_user_id, ProfileInfo("I love support me", None) + ) ) profile = self.get_success(self.store.get_user_in_directory(support_user_id)) self.assertTrue(profile is None) @@ -77,7 +240,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) self.assertTrue(profile["display_name"] == display_name) - def test_handle_local_profile_change_with_deactivated_user(self): + def test_handle_local_profile_change_with_deactivated_user(self) -> None: # create user r_user_id = "@regular:test" self.get_success( @@ -112,7 +275,27 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): profile = self.get_success(self.store.get_user_in_directory(r_user_id)) self.assertTrue(profile is None) - def test_handle_user_deactivated_support_user(self): + def test_handle_local_profile_change_with_appservice_user(self) -> None: + # create user + as_user_id = self.register_appservice_user( + "as_user_alice", self.appservice.token + ) + + # profile is not in directory + profile = self.get_success(self.store.get_user_in_directory(as_user_id)) + self.assertTrue(profile is None) + + # update profile + profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3") + self.get_success( + self.handler.handle_local_profile_change(as_user_id, profile_info) + ) + + # profile is still not in directory + profile = self.get_success(self.store.get_user_in_directory(as_user_id)) + self.assertTrue(profile is None) + + def test_handle_user_deactivated_support_user(self) -> None: s_user_id = "@support:test" self.get_success( self.store.register_user( @@ -120,20 +303,29 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) - self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) - self.get_success(self.handler.handle_local_user_deactivated(s_user_id)) - self.store.remove_from_user_dir.not_called() + mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + with patch.object( + self.store, "remove_from_user_dir", mock_remove_from_user_dir + ): + self.get_success(self.handler.handle_local_user_deactivated(s_user_id)) + # BUG: the correct spelling is assert_not_called, but that makes the test fail + # and it's not clear that this is actually the behaviour we want. + mock_remove_from_user_dir.not_called() - def test_handle_user_deactivated_regular_user(self): + def test_handle_user_deactivated_regular_user(self) -> None: r_user_id = "@regular:test" self.get_success( self.store.register_user(user_id=r_user_id, password_hash=None) ) - self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) - self.get_success(self.handler.handle_local_user_deactivated(r_user_id)) - self.store.remove_from_user_dir.called_once_with(r_user_id) - def test_reactivation_makes_regular_user_searchable(self): + mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + with patch.object( + self.store, "remove_from_user_dir", mock_remove_from_user_dir + ): + self.get_success(self.handler.handle_local_user_deactivated(r_user_id)) + mock_remove_from_user_dir.assert_called_once_with(r_user_id) + + def test_reactivation_makes_regular_user_searchable(self) -> None: user = self.register_user("regular", "pass") user_token = self.login(user, "pass") admin_user = self.register_user("admin", "pass", admin=True) @@ -171,7 +363,147 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) self.assertEqual(s["results"][0]["user_id"], user) - def test_private_room(self): + def test_process_join_after_server_leaves_room(self) -> None: + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + + # Alice makes two rooms. Bob joins one of them. + room1 = self.helper.create_room_as(alice, tok=alice_token) + room2 = self.helper.create_room_as(alice, tok=alice_token) + self.helper.join(room1, bob, tok=bob_token) + + # The user sharing tables should have been updated. + public1 = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + self.assertEqual(set(public1), {(alice, room1), (alice, room2), (bob, room1)}) + + # Alice leaves room1. The user sharing tables should be updated. + self.helper.leave(room1, alice, tok=alice_token) + public2 = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + self.assertEqual(set(public2), {(alice, room2), (bob, room1)}) + + # Pause the processing of new events. + dir_handler = self.hs.get_user_directory_handler() + dir_handler.update_user_directory = False + + # Bob leaves one room and joins the other. + self.helper.leave(room1, bob, tok=bob_token) + self.helper.join(room2, bob, tok=bob_token) + + # Process the leave and join in one go. + dir_handler.update_user_directory = True + dir_handler.notify_new_event() + self.wait_for_background_updates() + + # The user sharing tables should have been updated. + public3 = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + self.assertEqual(set(public3), {(alice, room2), (bob, room2)}) + + def test_per_room_profile_doesnt_alter_directory_entry(self) -> None: + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + + # Alice should have a user directory entry created at registration. + users = self.get_success(self.user_dir_helper.get_profiles_in_user_directory()) + self.assertEqual( + users[alice], ProfileInfo(display_name="alice", avatar_url=None) + ) + + # Alice makes a room for herself. + room = self.helper.create_room_as(alice, is_public=True, tok=alice_token) + + # Alice sets a nickname unique to that room. + self.helper.send_state( + room, + "m.room.member", + { + "displayname": "Freddy Mercury", + "membership": "join", + }, + alice_token, + state_key=alice, + ) + + # Alice's display name remains the same in the user directory. + search_result = self.get_success(self.handler.search_users(bob, alice, 10)) + self.assertEqual( + search_result["results"], + [{"display_name": "alice", "avatar_url": None, "user_id": alice}], + 0, + ) + + def test_making_room_public_doesnt_alter_directory_entry(self) -> None: + """Per-room names shouldn't go to the directory when the room becomes public. + + This isn't about preventing a leak (the room is now public, so the nickname + is too). It's about preserving the invariant that we only show a user's public + profile in the user directory results. + + I made this a Synapse test case rather than a Complement one because + I think this is (strictly speaking) an implementation choice. Synapse + has chosen to only ever use the public profile when responding to a user + directory search. There's no privacy leak here, because making the room + public discloses the per-room name. + + The spec doesn't mandate anything about _how_ a user + should appear in a /user_directory/search result. Hypothetical example: + suppose Bob searches for Alice. When representing Alice in a search + result, it's reasonable to use any of Alice's nicknames that Bob is + aware of. Heck, maybe we even want to use lots of them in a combined + displayname like `Alice (aka "ali", "ally", "41iC3")`. + """ + + # TODO the same should apply when Alice is a remote user. + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + + # Alice and Bob are in a private room. + room = self.helper.create_room_as(alice, is_public=False, tok=alice_token) + self.helper.invite(room, src=alice, targ=bob, tok=alice_token) + self.helper.join(room, user=bob, tok=bob_token) + + # Alice has a nickname unique to that room. + + self.helper.send_state( + room, + "m.room.member", + { + "displayname": "Freddy Mercury", + "membership": "join", + }, + alice_token, + state_key=alice, + ) + + # Check Alice isn't recorded as being in a public room. + public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + self.assertNotIn((alice, room), public) + + # One of them makes the room public. + self.helper.send_state( + room, + "m.room.join_rules", + {"join_rule": "public"}, + alice_token, + ) + + # Check that Alice is now recorded as being in a public room + public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + self.assertIn((alice, room), public) + + # Alice's display name remains the same in the user directory. + search_result = self.get_success(self.handler.search_users(bob, alice, 10)) + self.assertEqual( + search_result["results"], + [{"display_name": "alice", "avatar_url": None, "user_id": alice}], + 0, + ) + + def test_private_room(self) -> None: """ A user can be searched for only by people that are either in a public room, or that share a private chat. @@ -191,11 +523,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.helper.join(room, user=u2, tok=u2_token) # Check we have populated the database correctly. - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) self.assertEqual( - self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} + self.user_dir_helper._compress_shared(shares_private), + {(u1, u2, room), (u2, u1, room)}, ) self.assertEqual(public_users, []) @@ -215,10 +552,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.helper.leave(room, user=u2, tok=u2_token) # Check we have removed the values. - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) - self.assertEqual(self._compress_shared(shares_private), set()) + self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) self.assertEqual(public_users, []) # User1 now gets no search results for any of the other users. @@ -228,7 +569,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user3", 10)) self.assertEqual(len(s["results"]), 0) - def test_spam_checker(self): + def test_spam_checker(self) -> None: """ A user which fails the spam checks will not appear in search results. """ @@ -246,11 +587,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.helper.join(room, user=u2, tok=u2_token) # Check we have populated the database correctly. - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) self.assertEqual( - self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} + self.user_dir_helper._compress_shared(shares_private), + {(u1, u2, room), (u2, u1, room)}, ) self.assertEqual(public_users, []) @@ -258,7 +604,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) - async def allow_all(user_profile): + async def allow_all(user_profile: ProfileInfo) -> bool: # Allow all users. return False @@ -272,7 +618,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - async def block_all(user_profile): + async def block_all(user_profile: ProfileInfo) -> bool: # All users are spammy. return True @@ -282,7 +628,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 0) - def test_legacy_spam_checker(self): + def test_legacy_spam_checker(self) -> None: """ A spam checker without the expected method should be ignored. """ @@ -300,11 +646,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.helper.join(room, user=u2, tok=u2_token) # Check we have populated the database correctly. - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) self.assertEqual( - self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} + self.user_dir_helper._compress_shared(shares_private), + {(u1, u2, room), (u2, u1, room)}, ) self.assertEqual(public_users, []) @@ -317,134 +668,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) - def _compress_shared(self, shared): - """ - Compress a list of users who share rooms dicts to a list of tuples. - """ - r = set() - for i in shared: - r.add((i["user_id"], i["other_user_id"], i["room_id"])) - return r - - def get_users_in_public_rooms(self) -> List[Tuple[str, str]]: - r = self.get_success( - self.store.db_pool.simple_select_list( - "users_in_public_rooms", None, ("user_id", "room_id") - ) - ) - retval = [] - for i in r: - retval.append((i["user_id"], i["room_id"])) - return retval - - def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]: - return self.get_success( - self.store.db_pool.simple_select_list( - "users_who_share_private_rooms", - None, - ["user_id", "other_user_id", "room_id"], - ) - ) - - def _add_background_updates(self): - """ - Add the background updates we need to run. - """ - # Ugh, have to reset this flag - self.store.db_pool.updates._all_done = False - - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_user_directory_createtables", - "progress_json": "{}", - }, - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_user_directory_process_rooms", - "progress_json": "{}", - "depends_on": "populate_user_directory_createtables", - }, - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_user_directory_process_users", - "progress_json": "{}", - "depends_on": "populate_user_directory_process_rooms", - }, - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_user_directory_cleanup", - "progress_json": "{}", - "depends_on": "populate_user_directory_process_users", - }, - ) - ) - - def test_initial(self): - """ - The user directory's initial handler correctly updates the search tables. - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - u3 = self.register_user("user3", "pass") - u3_token = self.login(u3, "pass") - - room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) - self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token) - self.helper.join(private_room, user=u3, tok=u3_token) - - self.get_success(self.store.update_user_directory_stream_pos(None)) - self.get_success(self.store.delete_all_from_user_dir()) - - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() - - # Nothing updated yet - self.assertEqual(shares_private, []) - self.assertEqual(public_users, []) - - # Do the initial population of the user directory via the background update - self._add_background_updates() - - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() - - # User 1 and User 2 are in the same public room - self.assertEqual(set(public_users), {(u1, room), (u2, room)}) - - # User 1 and User 3 share private rooms - self.assertEqual( - self._compress_shared(shares_private), - {(u1, u3, private_room), (u3, u1, private_room)}, - ) - - def test_initial_share_all_users(self): + def test_initial_share_all_users(self) -> None: """ Search all users = True means that a user does not have to share a private room with the searching user or be in a public room to be search @@ -457,26 +681,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.register_user("user2", "pass") u3 = self.register_user("user3", "pass") - # Wipe the user dir - self.get_success(self.store.update_user_directory_stream_pos(None)) - self.get_success(self.store.delete_all_from_user_dir()) - - # Do the initial population of the user directory via the background update - self._add_background_updates() - - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) # No users share rooms self.assertEqual(public_users, []) - self.assertEqual(self._compress_shared(shares_private), set()) + self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) # Despite not sharing a room, search_all_users means we get a search # result. @@ -501,7 +715,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): } } ) - def test_prefer_local_users(self): + def test_prefer_local_users(self) -> None: """Tests that local users are shown higher in search results when user_directory.prefer_local_users is True. """ @@ -535,15 +749,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): local_users = [local_user_1, local_user_2, local_user_3] remote_users = [remote_user_1, remote_user_2, remote_user_3] - # Populate the user directory via background update - self._add_background_updates() - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - # The local searching user searches for the term "user", which other users have # in their user id results = self.get_success( @@ -565,7 +770,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): room_id: str, room_version: RoomVersion, user_id: str, - ): + ) -> None: # Add a user to the room. builder = self.event_builder_factory.for_room_version( room_version, @@ -588,8 +793,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): class TestUserDirSearchDisabled(unittest.HomeserverTestCase): - user_id = "@test:test" - servlets = [ user_directory.register_servlets, room.register_servlets, @@ -597,7 +800,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["update_user_directory"] = True hs = self.setup_test_homeserver(config=config) @@ -606,19 +809,24 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): return hs - def test_disabling_room_list(self): + def test_disabling_room_list(self) -> None: self.config.userdirectory.user_directory_search_enabled = True - # First we create a room with another user so that user dir is non-empty - # for our user - self.helper.create_room_as(self.user_id) + # Create two users and put them in the same room. + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") u2 = self.register_user("user2", "pass") - room = self.helper.create_room_as(self.user_id) - self.helper.join(room, user=u2) + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) - # Assert user directory is not empty + # Each should see the other when searching the user directory. channel = self.make_request( - "POST", b"user_directory/search", b'{"search_term":"user2"}' + "POST", + b"user_directory/search", + b'{"search_term":"user2"}', + access_token=u1_token, ) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) > 0) @@ -626,7 +834,10 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): # Disable user directory and check search returns nothing self.config.userdirectory.user_directory_search_enabled = False channel = self.make_request( - "POST", b"user_directory/search", b'{"search_term":"user2"}' + "POST", + b"user_directory/search", + b'{"search_term":"user2"}', + access_token=u1_token, ) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index d9a8b077d3..638babae69 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py
@@ -226,7 +226,7 @@ class FederationClientTests(HomeserverTestCase): """Ensure that Synapse does not try to connect to blacklisted IPs""" # Set up the ip_range blacklist - self.hs.config.federation_ip_range_blacklist = IPSet( + self.hs.config.server.federation_ip_range_blacklist = IPSet( ["127.0.0.0/8", "fe80::/64"] ) self.reactor.lookups["internal"] = "127.0.0.1" diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index f73fcd684e..96f399b7ab 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py
@@ -198,3 +198,31 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(log["url"], "/_matrix/client/versions") self.assertEqual(log["protocol"], "1.1") self.assertEqual(log["user_agent"], "") + + def test_with_exception(self): + """ + The logging exception type & value should be added to the JSON response. + """ + handler = logging.StreamHandler(self.output) + handler.setFormatter(JsonFormatter()) + logger = self.get_logger(handler) + + try: + raise ValueError("That's wrong, you wally!") + except ValueError: + logger.exception("Hello there, %s!", "wally") + + log = self.get_log_line() + + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "level", + "namespace", + "exc_type", + "exc_value", + ] + self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") + self.assertEqual(log["exc_type"], "ValueError") + self.assertEqual(log["exc_value"], "That's wrong, you wally!") diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9d38974fba..e915dd5c7c 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
@@ -25,6 +25,7 @@ from synapse.types import create_requester from tests.events.test_presence_router import send_presence_update, sync_presence from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.test_utils import simple_async_mock from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config from tests.utils import USE_POSTGRES_FOR_TESTS @@ -46,8 +47,12 @@ class ModuleApiTestCase(HomeserverTestCase): self.auth_handler = homeserver.get_auth_handler() def make_homeserver(self, reactor, clock): + # Mock out the calls over federation. + fed_transport_client = Mock(spec=["send_transaction"]) + fed_transport_client.send_transaction = simple_async_mock({}) + return self.setup_test_homeserver( - federation_transport_client=Mock(spec=["send_transaction"]), + federation_transport_client=fed_transport_client, ) def test_can_register_user(self): diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index c7555c26db..eac4664b41 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -70,8 +70,16 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # databases objects are the same. self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool + # Normally we'd pass in the handler to `setup_test_homeserver`, which would + # eventually hit "Install @cache_in_self attributes" in tests/utils.py. + # Unfortunately our handler wants a reference to the homeserver. That leaves + # us with a chicken-and-egg problem. + # We can workaround this: create the homeserver first, create the handler + # and bodge it in after the fact. The bodging requires us to know the + # dirty details of how `cache_in_self` works. We politely ask mypy to + # ignore our dirty dealings. self.test_handler = self._build_replication_data_handler() - self.worker_hs._replication_data_handler = self.test_handler + self.worker_hs._replication_data_handler = self.test_handler # type: ignore[attr-defined] repl_handler = ReplicationCommandHandler(self.worker_hs) self.client = ClientReplicationStreamProtocol( @@ -240,7 +248,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): if self.hs.config.redis.redis_enabled: # Handle attempts to connect to fake redis server. self.reactor.add_tcp_client_callback( - b"localhost", + "localhost", 6379, self.connect_any_redis_attempts, ) @@ -315,12 +323,15 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): ) ) + # Copy the port into a new, non-Optional variable so mypy knows we're + # not going to reset `instance_loc` to `None` under its feet. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions + port = instance_loc.port + self.reactor.add_tcp_client_callback( self.reactor.lookups[instance_loc.host], instance_loc.port, - lambda: self._handle_http_replication_attempt( - worker_hs, instance_loc.port - ), + lambda: self._handle_http_replication_attempt(worker_hs, port), ) store = worker_hs.get_datastore() @@ -424,7 +435,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): clients = self.reactor.tcpClients while clients: (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, b"localhost") + self.assertEqual(host, "localhost") self.assertEqual(port, 6379) client_protocol = client_factory.buildProtocol(None) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ee3ae9cce4..6ed9e42173 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -59,7 +59,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver() - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" self.hs.get_media_repository = Mock() self.hs.get_deactivate_account_handler = Mock() @@ -71,7 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): If there is no shared secret, registration through this method will be prevented. """ - self.hs.config.registration_shared_secret = None + self.hs.config.registration.registration_shared_secret = None channel = self.make_request("POST", self.url, b"{}") @@ -422,7 +422,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1485,7 +1485,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1522,7 +1522,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 9e9e953cf4..89d85b0a17 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py
@@ -470,13 +470,45 @@ class WhoamiTestCase(unittest.HomeserverTestCase): register.register_servlets, ] + def default_config(self): + config = super().default_config() + config["allow_guest_access"] = True + return config + def test_GET_whoami(self): device_id = "wouldgohere" user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test", device_id=device_id) - whoami = self.whoami(tok) - self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id}) + whoami = self._whoami(tok) + self.assertEqual( + whoami, + { + "user_id": user_id, + "device_id": device_id, + # Unstable until MSC3069 enters spec + "org.matrix.msc3069.is_guest": False, + }, + ) + + def test_GET_whoami_guests(self): + channel = self.make_request( + b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}" + ) + tok = channel.json_body["access_token"] + user_id = channel.json_body["user_id"] + device_id = channel.json_body["device_id"] + + whoami = self._whoami(tok) + self.assertEqual( + whoami, + { + "user_id": user_id, + "device_id": device_id, + # Unstable until MSC3069 enters spec + "org.matrix.msc3069.is_guest": True, + }, + ) def test_GET_whoami_appservices(self): user_id = "@as:test" @@ -484,18 +516,25 @@ class WhoamiTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, ) self.hs.get_datastore().services_cache.append(appservice) - whoami = self.whoami(as_token) - self.assertEqual(whoami, {"user_id": user_id}) + whoami = self._whoami(as_token) + self.assertEqual( + whoami, + { + "user_id": user_id, + # Unstable until MSC3069 enters spec + "org.matrix.msc3069.is_guest": False, + }, + ) self.assertFalse(hasattr(whoami, "device_id")) - def whoami(self, tok): + def _whoami(self, tok): channel = self.make_request("GET", "account/whoami", {}, access_token=tok) self.assertEqual(channel.code, 200) return channel.json_body @@ -625,7 +664,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -695,7 +734,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test deleting an email from profile when disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False # Add a threepid self.get_success( diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index 422361b62a..b9e3602552 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py
@@ -55,7 +55,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) self.assertEqual( - self.config.default_room_version.identifier, + self.config.server.default_room_version.identifier, capabilities["m.room_versions"]["default"], ) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index ca2e8ff8ef..becb4e8dcc 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py
@@ -37,7 +37,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): return self.hs def test_3pid_lookup_disabled(self): - self.hs.config.enable_3pid_lookup = False + self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 371615a015..a63f04bd41 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py
@@ -94,9 +94,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() - self.hs.config.enable_registration = True - self.hs.config.registrations_require_3pid = [] - self.hs.config.auto_join_rooms = [] + self.hs.config.registration.enable_registration = True + self.hs.config.registration.registrations_require_3pid = [] + self.hs.config.registration.auto_join_rooms = [] self.hs.config.captcha.enable_registration_captcha = False return self.hs @@ -1064,13 +1064,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def register_as_user(self, username): - self.make_request( - b"POST", - "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), - {"username": username}, - ) - def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() @@ -1107,7 +1100,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_user(self): """Test that an appservice user can use /login""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1121,7 +1114,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_user_bot(self): """Test that the appservice bot can use /login""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1135,7 +1128,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_wrong_user(self): """Test that non-as users cannot login with the as token""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1149,7 +1142,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_wrong_as(self): """Test that as users cannot login with wrong as token""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1165,7 +1158,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): """Test that users must provide a token when using the appservice login method """ - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index 1d152352d1..56fe1a3d01 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py
@@ -50,7 +50,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): PUT to the status endpoint with use_presence enabled will call set_state on the presence handler. """ - self.hs.config.use_presence = True + self.hs.config.server.use_presence = True body = {"presence": "here", "status_msg": "beep boop"} channel = self.make_request( diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 72a5a11b46..66dcfc9f88 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py
@@ -50,7 +50,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", @@ -74,7 +74,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", @@ -147,7 +147,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_guest_registration(self): self.hs.config.key.macaroon_secret_key = "test" - self.hs.config.allow_guest_access = True + self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self): - self.hs.config.allow_guest_access = False + self.hs.config.registration.allow_guest_access = False channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 30bdaa9c27..376853fd65 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -784,6 +784,30 @@ class RoomsCreateTestCase(RoomBase): # Check that do_3pid_invite wasn't called this time. self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids)) + def test_spam_checker_may_join_room(self): + """Tests that the user_may_join_room spam checker callback is correctly bypassed + when creating a new room. + """ + + async def user_may_join_room( + mxid: str, + room_id: str, + is_invite: bool, + ) -> bool: + return False + + join_mock = Mock(side_effect=user_may_join_room) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock) + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEquals(channel.code, 200, channel.json_body) + + self.assertEquals(join_mock.call_count, 0) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -975,6 +999,83 @@ class RoomInviteRatelimitTestCase(RoomBase): self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) +class RoomJoinTestCase(RoomBase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user1 = self.register_user("thomas", "hackme") + self.tok1 = self.login("thomas", "hackme") + + self.user2 = self.register_user("teresa", "hackme") + self.tok2 = self.login("teresa", "hackme") + + self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + + def test_spam_checker_may_join_room(self): + """Tests that the user_may_join_room spam checker callback is correctly called + and blocks room joins when needed. + """ + + # Register a dummy callback. Make it allow all room joins for now. + return_value = True + + async def user_may_join_room( + userid: str, + room_id: str, + is_invited: bool, + ) -> bool: + return return_value + + callback_mock = Mock(side_effect=user_may_join_room) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) + + # Join a first room, without being invited to it. + self.helper.join(self.room1, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room1, + False, + ), + ) + self.assertEquals( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Join a second room, this time with an invite for it. + self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1) + self.helper.join(self.room2, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room2, + True, + ), + ) + self.assertEquals( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + return_value = False + self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + + class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" @@ -2430,3 +2531,73 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): """An alias which does not point to the room raises a SynapseError.""" self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) + + +class ThreepidInviteTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("thomas", "hackme") + self.tok = self.login("thomas", "hackme") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_threepid_invite_spamcheck(self): + # Mock a few functions to prevent the test from failing due to failing to talk to + # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we + # can check its call_count later on during the test. + make_invite_mock = Mock(return_value=make_awaitable(0)) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock + self.hs.get_identity_handler().lookup_3pid = Mock( + return_value=make_awaitable(None), + ) + + # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it + # allow everything for now. + mock = Mock(return_value=make_awaitable(True)) + self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) + + # Send a 3PID invite into the room and check that it succeeded. + email_to_invite = "teresa@example.com" + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEquals(channel.code, 200) + + # Check that the callback was called with the right params. + mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) + + # Check that the call to send the invite was made. + make_invite_mock.assert_called_once() + + # Now change the return value of the callback to deny any invite and test that + # we can't send the invite. + mock.return_value = make_awaitable(False) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEquals(channel.code, 403) + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once() diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 3075d3f288..71fa87ce92 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py
@@ -48,7 +48,7 @@ class RestHelper: def create_room_as( self, room_creator: Optional[str] = None, - is_public: bool = True, + is_public: Optional[bool] = None, room_version: Optional[str] = None, tok: Optional[str] = None, expect_code: int = 200, @@ -62,9 +62,10 @@ class RestHelper: Args: room_creator: The user ID to create the room with. - is_public: If True, the `visibility` parameter will be set to the - default (public). Otherwise, the `visibility` parameter will be set - to "private". + is_public: If True, the `visibility` parameter will be set to + "public". If False, it will be set to "private". If left + unspecified, the server will set it to an appropriate default + (which should be "private" as per the CS spec). room_version: The room version to create the room as. Defaults to Synapse's default room version. tok: The access token to use in the request. @@ -77,8 +78,8 @@ class RestHelper: self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" content = extra_content or {} - if not is_public: - content["visibility"] = "private" + if is_public is not None: + content["visibility"] = "public" if is_public else "private" if room_version: content["room_version"] = room_version if tok: diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 4d09b5d07e..8698135a76 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py
@@ -21,11 +21,13 @@ from twisted.internet.error import DNSLookupError from twisted.test.proto_helpers import AccumulatingProtocol from synapse.config.oembed import OEmbedEndpointConfig +from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest from tests.server import FakeTransport from tests.test_utils import SMALL_PNG +from tests.utils import MockClock try: import lxml @@ -723,9 +725,107 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) + def test_oembed_autodiscovery(self): + """ + Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL. + 1. Request a preview of a URL which is not known to the oEmbed code. + 2. It returns HTML including a link to an oEmbed preview. + 3. The oEmbed preview is requested and returns a URL for an image. + 4. The image is requested for thumbnailing. + """ + # This is a little cheesy in that we use the www subdomain (which isn't the + # list of oEmbed patterns) to get "raw" HTML response. + self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")] + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + result = b""" + <link rel="alternate" type="application/json+oembed" + href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json" + title="matrixdotorg" /> + """ + + channel = self.make_request( + "GET", + "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(result),) + + result + ) + + self.pump() + + # The oEmbed response. + result2 = { + "version": "1.0", + "type": "photo", + "url": "http://cdn.twitter.com/matrixdotorg", + } + oembed_content = json.dumps(result2).encode("utf-8") + + # Ensure a second request is made to the oEmbed URL. + client = self.reactor.tcpClients[1][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' + ) + % (len(oembed_content),) + + oembed_content + ) + + self.pump() + + # Ensure the URL is what was requested. + self.assertIn(b"/oembed?", server.data) + + # Ensure a third request is made to the photo URL. + client = self.reactor.tcpClients[2][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b"Content-Type: image/png\r\n\r\n" + ) + % (len(SMALL_PNG),) + + SMALL_PNG + ) + + self.pump() + + # Ensure the URL is what was requested. + self.assertIn(b"/matrixdotorg", server.data) + + self.assertEqual(channel.code, 200) + body = channel.json_body + self.assertEqual( + body["og:url"], "http://www.twitter.com/matrixdotorg/status/12345" + ) + self.assertTrue(body["og:image"].startswith("mxc://")) + self.assertEqual(body["og:image:height"], 1) + self.assertEqual(body["og:image:width"], 1) + self.assertEqual(body["og:image:type"], "image/png") + def _download_image(self): """Downloads an image into the URL cache. - Returns: A (host, media_id) tuple representing the MXC URI of the image. """ @@ -851,3 +951,32 @@ class URLPreviewTests(unittest.HomeserverTestCase): 404, "URL cache thumbnail was unexpectedly retrieved from a storage provider", ) + + def test_cache_expiry(self): + """Test that URL cache files and thumbnails are cleaned up properly on expiry.""" + self.preview_url.clock = MockClock() + + _host, media_id = self._download_image() + + file_path = self.preview_url.filepaths.url_cache_filepath(media_id) + file_dirs = self.preview_url.filepaths.url_cache_filepath_dirs_to_delete( + media_id + ) + thumbnail_dir = self.preview_url.filepaths.url_cache_thumbnail_directory( + media_id + ) + thumbnail_dirs = self.preview_url.filepaths.url_cache_thumbnail_dirs_to_delete( + media_id + ) + + self.assertTrue(os.path.isfile(file_path)) + self.assertTrue(os.path.isdir(thumbnail_dir)) + + self.preview_url.clock.advance_time_msec(IMAGE_CACHE_EXPIRY_MS + 1) + self.get_success(self.preview_url._expire_url_cache_data()) + + for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs: + self.assertFalse( + os.path.exists(path), + f"{os.path.relpath(path, self.media_store_path)} was not deleted", + ) diff --git a/tests/server.py b/tests/server.py
index 88dfa8058e..64645651ce 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -317,7 +317,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def __init__(self): self.threadpool = ThreadPool(self) - self._tcp_callbacks = {} + self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {} self._udp = [] self.lookups: Dict[str, str] = {} self._thread_callbacks: Deque[Callable[[], None]] = deque() @@ -355,7 +355,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def getThreadPool(self): return self.threadpool - def add_tcp_client_callback(self, host, port, callback): + def add_tcp_client_callback(self, host: str, port: int, callback: Callable): """Add a callback that will be invoked when we receive a connection attempt to the given IP/port using `connectTCP`. @@ -364,7 +364,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ self._tcp_callbacks[(host, port)] = callback - def connectTCP(self, host, port, factory, timeout=30, bindAddress=None): + def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None): """Fake L{IReactorTCP.connectTCP}.""" conn = super().connectTCP( @@ -475,7 +475,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): return server -def get_clock(): +def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: clock = ThreadedMemoryReactorClock() hs_clock = Clock(clock) return clock, hs_clock diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 7f25200a5d..36c495954f 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -346,7 +346,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): invites = [] # Register as many users as the MAU limit allows. - for i in range(self.hs.config.max_mau_value): + for i in range(self.hs.config.server.max_mau_value): localpart = "user%d" % i user_id = self.register_user(localpart, "password") tok = self.login(localpart, "password") diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index ffee707153..7496974da3 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py
@@ -79,12 +79,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() # Make sure the background update filled in the room creator room_creator_after = self.get_success( diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index cf9748f218..f26d5acf9c 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py
@@ -126,7 +126,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.db_pool = database._db_pool self.engine = database.engine - db_config = hs.config.get_single_database() + db_config = hs.config.database.get_single_database() self.store = TestTransactionStore( database, make_conn(db_config, self.engine, "test"), hs ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 7cc5e621ba..a59c28f896 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -66,12 +66,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Ugh, have to reset this flag self.store.db_pool.updates._all_done = False - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 3cc8038f1e..dada4f98c9 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -242,12 +242,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_devices_last_seen_bg_update(self): # First make sure we have completed all updates. - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() user_id = "@user:id" device_id = "MY_DEVICE" @@ -311,12 +306,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() # We should now get the correct result again result = self.get_success( @@ -337,12 +327,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_old_user_ips_pruned(self): # First make sure we have completed all updates. - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() user_id = "@user:id" device_id = "MY_DEVICE" diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 93136f0717..b31c5eb5ec 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py
@@ -578,12 +578,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Ugh, have to reset this flag self.store.db_pool.updates._all_done = False - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() # Test that the `has_auth_chain_index` has been set self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) @@ -619,12 +614,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Ugh, have to reset this flag self.store.db_pool.updates._all_done = False - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() # Test that the `has_auth_chain_index` has been set self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 944dbc34a2..d6b4cdd788 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py
@@ -51,7 +51,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) def test_initialise_reserved_users(self): - threepids = self.hs.config.mau_limits_reserved_threepids + threepids = self.hs.config.server.mau_limits_reserved_threepids # register three users, of which two have reserved 3pids, and a third # which is a support user. @@ -101,9 +101,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # XXX some of this is redundant. poking things into the config shouldn't # work, and in any case it's not obvious what we expect to happen when # we advance the reactor. - self.hs.config.max_mau_value = 0 + self.hs.config.server.max_mau_value = 0 self.reactor.advance(FORTY_DAYS) - self.hs.config.max_mau_value = 5 + self.hs.config.server.max_mau_value = 5 self.get_success(self.store.reap_monthly_active_users()) @@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(d) count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, self.hs.config.max_mau_value) + self.assertEqual(count, self.hs.config.server.max_mau_value) self.reactor.advance(FORTY_DAYS) @@ -199,7 +199,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_reap_monthly_active_users_reserved_users(self): """Tests that reaping correctly handles reaping where reserved users are present""" - threepids = self.hs.config.mau_limits_reserved_threepids + threepids = self.hs.config.server.mau_limits_reserved_threepids initial_users = len(threepids) reserved_user_number = initial_users - 1 for i in range(initial_users): @@ -234,7 +234,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(d) count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, self.hs.config.max_mau_value) + self.assertEqual(count, self.hs.config.server.max_mau_value) def test_populate_monthly_users_is_guest(self): # Test that guest users are not added to mau list @@ -294,7 +294,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": user2_email}, ] - self.hs.config.mau_limits_reserved_threepids = threepids + self.hs.config.server.mau_limits_reserved_threepids = threepids d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index c72dc40510..2873e22ccf 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py
@@ -169,12 +169,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def test_can_rerun_update(self): # First make sure we have completed all updates. - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() # Now let's create a room, which will insert a membership user = UserID("alice", "test") @@ -197,9 +192,4 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index 6ff3ebb137..ace82cbf42 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py
@@ -22,7 +22,7 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): return self.setup_test_homeserver(db_txn_limit=1000) def test_config(self): - db_config = self.hs.config.get_single_database() + db_config = self.hs.config.database.get_single_database() self.assertEqual(db_config.config["txn_limit"], 1000) def test_select(self): diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 222e5d129d..9f483ad681 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py
@@ -11,7 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List, Set, Tuple +from unittest import mock +from unittest.mock import Mock, patch +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventTypes, Membership, UserTypes +from synapse.appservice import ApplicationService +from synapse.rest import admin +from synapse.rest.client import login, register, room +from synapse.server import HomeServer +from synapse.storage import DataStore +from synapse.storage.roommember import ProfileInfo +from synapse.util import Clock + +from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config ALICE = "@alice:a" @@ -21,8 +36,376 @@ BOBBY = "@bobby:a" BELA = "@somenickname:a" +class GetUserDirectoryTables: + """Helper functions that we want to reuse in tests/handlers/test_user_directory.py""" + + def __init__(self, store: DataStore): + self.store = store + + def _compress_shared( + self, shared: List[Dict[str, str]] + ) -> Set[Tuple[str, str, str]]: + """ + Compress a list of users who share rooms dicts to a list of tuples. + """ + r = set() + for i in shared: + r.add((i["user_id"], i["other_user_id"], i["room_id"])) + return r + + async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]: + """Fetch the entire `users_in_public_rooms` table. + + Returns a list of tuples (user_id, room_id) where room_id is public and + contains the user with the given id. + """ + r = await self.store.db_pool.simple_select_list( + "users_in_public_rooms", None, ("user_id", "room_id") + ) + + retval = [] + for i in r: + retval.append((i["user_id"], i["room_id"])) + return retval + + async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]: + """Fetch the entire `users_who_share_private_rooms` table. + + Returns a dict containing "user_id", "other_user_id" and "room_id" keys. + The dicts can be flattened to Tuples with the `_compress_shared` method. + (This seems a little awkward---maybe we could clean this up.) + """ + + return await self.store.db_pool.simple_select_list( + "users_who_share_private_rooms", + None, + ["user_id", "other_user_id", "room_id"], + ) + + async def get_users_in_user_directory(self) -> Set[str]: + """Fetch the set of users in the `user_directory` table. + + This is useful when checking we've correctly excluded users from the directory. + """ + result = await self.store.db_pool.simple_select_list( + "user_directory", + None, + ["user_id"], + ) + return {row["user_id"] for row in result} + + async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]: + """Fetch users and their profiles from the `user_directory` table. + + This is useful when we want to inspect display names and avatars. + It's almost the entire contents of the `user_directory` table: the only + thing missing is an unused room_id column. + """ + rows = await self.store.db_pool.simple_select_list( + "user_directory", + None, + ("user_id", "display_name", "avatar_url"), + ) + return { + row["user_id"]: ProfileInfo( + display_name=row["display_name"], avatar_url=row["avatar_url"] + ) + for row in rows + } + + +class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): + """Ensure that rebuilding the directory writes the correct data to the DB. + + See also tests/handlers/test_user_directory.py for similar checks. They + test the incremental updates, rather than the big rebuild. + """ + + servlets = [ + login.register_servlets, + admin.register_servlets, + room.register_servlets, + register.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.appservice = ApplicationService( + token="i_am_an_app_service", + hostname="test", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = super().make_homeserver(reactor, clock) + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastore() + self.user_dir_helper = GetUserDirectoryTables(self.store) + + def _purge_and_rebuild_user_dir(self) -> None: + """Nuke the user directory tables, start the background process to + repopulate them, and wait for the process to complete. This allows us + to inspect the outcome of the background process alone, without any of + the other incremental updates. + """ + self.get_success(self.store.update_user_directory_stream_pos(None)) + self.get_success(self.store.delete_all_from_user_dir()) + + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) + + # Nothing updated yet + self.assertEqual(shares_private, []) + self.assertEqual(public_users, []) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_createtables", + "progress_json": "{}", + }, + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_process_rooms", + "progress_json": "{}", + "depends_on": "populate_user_directory_createtables", + }, + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_process_users", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_rooms", + }, + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_cleanup", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_users", + }, + ) + ) + + self.wait_for_background_updates() + + def test_initial(self) -> None: + """ + The user directory's initial handler correctly updates the search tables. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + u3 = self.register_user("user3", "pass") + u3_token = self.login(u3, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token) + self.helper.join(private_room, user=u3, tok=u3_token) + + # Do the initial population of the user directory via the background update + self._purge_and_rebuild_user_dir() + + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) + + # User 1 and User 2 are in the same public room + self.assertEqual(set(public_users), {(u1, room), (u2, room)}) + + # User 1 and User 3 share private rooms + self.assertEqual( + self.user_dir_helper._compress_shared(shares_private), + {(u1, u3, private_room), (u3, u1, private_room)}, + ) + + # All three should have entries in the directory + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + self.assertEqual(users, {u1, u2, u3}) + + # The next three tests (test_population_excludes_*) all set up + # - A normal user included in the user dir + # - A public and private room created by that user + # - A user excluded from the room dir, belonging to both rooms + + # They match similar logic in handlers/test_user_directory.py But that tests + # updating the directory; this tests rebuilding it from scratch. + + def _create_rooms_and_inject_memberships( + self, creator: str, token: str, joiner: str + ) -> Tuple[str, str]: + """Create a public and private room as a normal user. + Then get the `joiner` into those rooms. + """ + public_room = self.helper.create_room_as( + creator, + is_public=True, + # See https://github.com/matrix-org/synapse/issues/10951 + extra_content={"visibility": "public"}, + tok=token, + ) + private_room = self.helper.create_room_as(creator, is_public=False, tok=token) + + # HACK: get the user into these rooms + self.get_success(inject_member_event(self.hs, public_room, joiner, "join")) + self.get_success(inject_member_event(self.hs, private_room, joiner, "join")) + + return public_room, private_room + + def _check_room_sharing_tables( + self, normal_user: str, public_room: str, private_room: str + ) -> None: + # After rebuilding the directory, we should only see the normal user. + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + self.assertEqual(users, {normal_user}) + in_public_rooms = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) + self.assertEqual(set(in_public_rooms), {(normal_user, public_room)}) + in_private_rooms = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + self.assertEqual(in_private_rooms, []) + + def test_population_excludes_support_user(self) -> None: + # Create a normal and support user. + user = self.register_user("user", "pass") + token = self.login(user, "pass") + support = "@support1:test" + self.get_success( + self.store.register_user( + user_id=support, password_hash=None, user_type=UserTypes.SUPPORT + ) + ) + + # Join the support user to rooms owned by the normal user. + public, private = self._create_rooms_and_inject_memberships( + user, token, support + ) + + # Rebuild the directory. + self._purge_and_rebuild_user_dir() + + # Check the support user is not in the directory. + self._check_room_sharing_tables(user, public, private) + + def test_population_excludes_deactivated_user(self) -> None: + user = self.register_user("naughty", "pass") + admin = self.register_user("admin", "pass", admin=True) + admin_token = self.login(admin, "pass") + + # Deactivate the user. + channel = self.make_request( + "PUT", + f"/_synapse/admin/v2/users/{user}", + access_token=admin_token, + content={"deactivated": True}, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["deactivated"], True) + + # Join the deactivated user to rooms owned by the admin. + # Is this something that could actually happen outside of a test? + public, private = self._create_rooms_and_inject_memberships( + admin, admin_token, user + ) + + # Rebuild the user dir. The deactivated user should be missing. + self._purge_and_rebuild_user_dir() + self._check_room_sharing_tables(admin, public, private) + + def test_population_excludes_appservice_user(self) -> None: + # Register an AS user. + user = self.register_user("user", "pass") + token = self.login(user, "pass") + as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + + # Join the AS user to rooms owned by the normal user. + public, private = self._create_rooms_and_inject_memberships( + user, token, as_user + ) + + # Rebuild the directory. + self._purge_and_rebuild_user_dir() + + # Check the AS user is not in the directory. + self._check_room_sharing_tables(user, public, private) + + def test_population_conceals_private_nickname(self) -> None: + # Make a private room, and set a nickname within + user = self.register_user("aaaa", "pass") + user_token = self.login(user, "pass") + private_room = self.helper.create_room_as(user, is_public=False, tok=user_token) + self.helper.send_state( + private_room, + EventTypes.Member, + state_key=user, + body={"membership": Membership.JOIN, "displayname": "BBBB"}, + tok=user_token, + ) + + # Rebuild the user directory. Make the rescan of the `users` table a no-op + # so we only see the effect of scanning the `room_memberships` table. + async def mocked_process_users(*args: Any, **kwargs: Any) -> int: + await self.store.db_pool.updates._end_background_update( + "populate_user_directory_process_users" + ) + return 1 + + with mock.patch.dict( + self.store.db_pool.updates._background_update_handlers, + populate_user_directory_process_users=mocked_process_users, + ): + self._purge_and_rebuild_user_dir() + + # Local users are ignored by the scan over rooms + users = self.get_success(self.user_dir_helper.get_profiles_in_user_directory()) + self.assertEqual(users, {}) + + # Do a full rebuild including the scan over the `users` table. The local + # user should appear with their profile name. + self._purge_and_rebuild_user_dir() + users = self.get_success(self.user_dir_helper.get_profiles_in_user_directory()) + self.assertEqual( + users, {user: ProfileInfo(display_name="aaaa", avatar_url=None)} + ) + + class UserDirectoryStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares @@ -33,7 +416,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None)) self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))) - def test_search_user_dir(self): + def test_search_user_dir(self) -> None: # normally when alice searches the directory she should just find # bob because bobby doesn't share a room with her. r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10)) @@ -44,7 +427,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): ) @override_config({"user_directory": {"search_all_users": True}}) - def test_search_user_dir_all_users(self): + def test_search_user_dir_all_users(self) -> None: r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(2, len(r["results"])) @@ -58,7 +441,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): ) @override_config({"user_directory": {"search_all_users": True}}) - def test_search_user_dir_stop_words(self): + def test_search_user_dir_stop_words(self) -> None: """Tests that a user can look up another user by searching for the start if its display name even if that name happens to be a common English word that would usually be ignored in full text searches. diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 1a4d078780..cf407c51cf 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py
@@ -38,21 +38,19 @@ class EventAuthTestCase(unittest.TestCase): } # creator should be able to send state - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _random_state_event(creator), auth_events, - do_sig_check=False, ) # joiner should not be able to send state self.assertRaises( AuthError, - event_auth.check, + event_auth.check_auth_rules_for_event, RoomVersions.V1, _random_state_event(joiner), auth_events, - do_sig_check=False, ) def test_state_default_level(self): @@ -77,19 +75,17 @@ class EventAuthTestCase(unittest.TestCase): # pleb should not be able to send state self.assertRaises( AuthError, - event_auth.check, + event_auth.check_auth_rules_for_event, RoomVersions.V1, _random_state_event(pleb), auth_events, - do_sig_check=False, ), # king should be able to send state - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _random_state_event(king), auth_events, - do_sig_check=False, ) def test_alias_event(self): @@ -102,37 +98,33 @@ class EventAuthTestCase(unittest.TestCase): } # creator should be able to send aliases - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(creator), auth_events, - do_sig_check=False, ) # Reject an event with no state key. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(creator, state_key=""), auth_events, - do_sig_check=False, ) # If the domain of the sender does not match the state key, reject. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(creator, state_key="test.com"), auth_events, - do_sig_check=False, ) # Note that the member does *not* need to be in the room. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(other), auth_events, - do_sig_check=False, ) def test_msc2432_alias_event(self): @@ -145,34 +137,30 @@ class EventAuthTestCase(unittest.TestCase): } # creator should be able to send aliases - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(creator), auth_events, - do_sig_check=False, ) # No particular checks are done on the state key. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(creator, state_key=""), auth_events, - do_sig_check=False, ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(creator, state_key="test.com"), auth_events, - do_sig_check=False, ) # Per standard auth rules, the member must be in the room. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(other), auth_events, - do_sig_check=False, ) def test_msc2209(self): @@ -192,20 +180,18 @@ class EventAuthTestCase(unittest.TestCase): } # pleb should be able to modify the notifications power level. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _power_levels_event(pleb, {"notifications": {"room": 100}}), auth_events, - do_sig_check=False, ) # But an MSC2209 room rejects this change. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _power_levels_event(pleb, {"notifications": {"room": 100}}), auth_events, - do_sig_check=False, ) def test_join_rules_public(self): @@ -222,59 +208,53 @@ class EventAuthTestCase(unittest.TestCase): } # Check join. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user cannot be force-joined to a room. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _member_event(pleb, "join", sender=creator), auth_events, - do_sig_check=False, ) # Banned should be rejected. auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user who left can re-join. auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can send a join if they're in the room. auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can accept an invite. auth_events[("m.room.member", pleb)] = _member_event( pleb, "invite", sender=creator ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) def test_join_rules_invite(self): @@ -292,60 +272,54 @@ class EventAuthTestCase(unittest.TestCase): # A join without an invite is rejected. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user cannot be force-joined to a room. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _member_event(pleb, "join", sender=creator), auth_events, - do_sig_check=False, ) # Banned should be rejected. auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user who left cannot re-join. auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can send a join if they're in the room. auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can accept an invite. auth_events[("m.room.member", pleb)] = _member_event( pleb, "invite", sender=creator ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) def test_join_rules_msc3083_restricted(self): @@ -370,11 +344,10 @@ class EventAuthTestCase(unittest.TestCase): # Older room versions don't understand this join rule with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A properly formatted join event should work. @@ -384,11 +357,10 @@ class EventAuthTestCase(unittest.TestCase): EventContentFields.AUTHORISING_USER: "@creator:example.com" }, ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, auth_events, - do_sig_check=False, ) # A join issued by a specific user works (i.e. the power level checks @@ -400,7 +372,7 @@ class EventAuthTestCase(unittest.TestCase): pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event( "@inviter:foo.test" ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event( pleb, @@ -409,16 +381,14 @@ class EventAuthTestCase(unittest.TestCase): }, ), pl_auth_events, - do_sig_check=False, ) # A join which is missing an authorised server is rejected. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), auth_events, - do_sig_check=False, ) # An join authorised by a user who is not in the room is rejected. @@ -427,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase): creator, {"invite": 100, "users": {"@other:example.com": 150}} ) with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event( pleb, @@ -436,13 +406,12 @@ class EventAuthTestCase(unittest.TestCase): }, ), auth_events, - do_sig_check=False, ) # A user cannot be force-joined to a room. (This uses an event which # *would* be valid, but is sent be a different user.) with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _member_event( pleb, @@ -453,36 +422,32 @@ class EventAuthTestCase(unittest.TestCase): }, ), auth_events, - do_sig_check=False, ) # Banned should be rejected. auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, auth_events, - do_sig_check=False, ) # A user who left can re-join. auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, auth_events, - do_sig_check=False, ) # A user can send a join if they're in the room. (This doesn't need to # be authorised since the user is already joined.) auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can accept an invite. (This doesn't need to be authorised since @@ -490,11 +455,10 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event( pleb, "invite", sender=creator ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), auth_events, - do_sig_check=False, ) diff --git a/tests/test_federation.py b/tests/test_federation.py
index c51e018da1..24fc77d7a7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -82,7 +82,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase): event, context, state=None, - claimed_auth_event_map=None, backfilled=False, ): return context diff --git a/tests/test_mau.py b/tests/test_mau.py
index 66111eb367..c683c8937e 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py
@@ -13,11 +13,11 @@ # limitations under the License. """Tests REST events for /rooms paths.""" - +import synapse.rest.admin from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.appservice import ApplicationService -from synapse.rest.client import register, sync +from synapse.rest.client import login, profile, register, sync from tests import unittest from tests.unittest import override_config @@ -26,7 +26,13 @@ from tests.utils import default_config class TestMauLimit(unittest.HomeserverTestCase): - servlets = [register.register_servlets, sync.register_servlets] + servlets = [ + register.register_servlets, + sync.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + profile.register_servlets, + login.register_servlets, + ] def default_config(self): config = default_config("test") @@ -165,7 +171,7 @@ class TestMauLimit(unittest.HomeserverTestCase): @override_config({"mau_trial_days": 1}) def test_trial_users_cant_come_back(self): - self.hs.config.mau_trial_days = 1 + self.hs.config.server.mau_trial_days = 1 # We should be able to register more than the limit initially token1 = self.create_user("kermit1") @@ -229,6 +235,31 @@ class TestMauLimit(unittest.HomeserverTestCase): self.reactor.advance(100) self.assertEqual(2, self.successResultOf(count)) + def test_deactivated_users_dont_count_towards_mau(self): + user1 = self.register_user("madonna", "password") + self.register_user("prince", "password2") + self.register_user("frodo", "onering", True) + + token1 = self.login("madonna", "password") + token2 = self.login("prince", "password2") + admin_token = self.login("frodo", "onering") + + self.do_sync_for_user(token1) + self.do_sync_for_user(token2) + + # Check that mau count is what we expect + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 2) + + # Deactivate user1 + url = "/_synapse/admin/v1/deactivate/%s" % user1 + channel = self.make_request("POST", url, access_token=admin_token) + self.assertIn("success", channel.json_body["id_server_unbind_result"]) + + # Check that deactivated user is no longer counted + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 1) + def create_user(self, localpart, token=None, appservice=False): request_data = { "username": localpart, diff --git a/tests/test_preview.py b/tests/test_preview.py
index 48e792b55b..09e017b4d9 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py
@@ -13,7 +13,8 @@ # limitations under the License. from synapse.rest.media.v1.preview_url_resource import ( - decode_and_calc_og, + _calc_og, + decode_body, get_html_media_encoding, summarize_paragraphs, ) @@ -158,7 +159,8 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - og = decode_and_calc_og(html, "http://example.com/test.html") + tree = decode_body(html) + og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -173,7 +175,8 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - og = decode_and_calc_og(html, "http://example.com/test.html") + tree = decode_body(html) + og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -191,7 +194,8 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - og = decode_and_calc_og(html, "http://example.com/test.html") + tree = decode_body(html) + og = _calc_og(tree, "http://example.com/test.html") self.assertEqual( og, @@ -212,7 +216,8 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - og = decode_and_calc_og(html, "http://example.com/test.html") + tree = decode_body(html) + og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -225,7 +230,8 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - og = decode_and_calc_og(html, "http://example.com/test.html") + tree = decode_body(html) + og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -239,7 +245,8 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - og = decode_and_calc_og(html, "http://example.com/test.html") + tree = decode_body(html) + og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) @@ -253,21 +260,22 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - og = decode_and_calc_og(html, "http://example.com/test.html") + tree = decode_body(html) + og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) def test_empty(self): """Test a body with no data in it.""" html = b"" - og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEqual(og, {}) + tree = decode_body(html) + self.assertIsNone(tree) def test_no_tree(self): """A valid body with no tree in it.""" html = b"\x00" - og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEqual(og, {}) + tree = decode_body(html) + self.assertIsNone(tree) def test_invalid_encoding(self): """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" @@ -279,9 +287,8 @@ class CalcOgTestCase(unittest.TestCase): </body> </html> """ - og = decode_and_calc_og( - html, "http://example.com/test.html", "invalid-encoding" - ) + tree = decode_body(html, "invalid-encoding") + og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_invalid_encoding2(self): @@ -295,7 +302,8 @@ class CalcOgTestCase(unittest.TestCase): </body> </html> """ - og = decode_and_calc_og(html, "http://example.com/test.html") + tree = decode_body(html) + og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) diff --git a/tests/unittest.py b/tests/unittest.py
index 7a6f5954d0..81c1a9e9d2 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import inspect import logging import secrets import time -from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union from unittest.mock import Mock, patch from canonicaljson import json @@ -28,6 +28,7 @@ from canonicaljson import json from twisted.internet.defer import Deferred, ensureDeferred, succeed from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool +from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest from twisted.web.resource import Resource @@ -46,6 +47,7 @@ from synapse.logging.context import ( ) from synapse.server import HomeServer from synapse.types import UserID, create_requester +from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree from synapse.util.ratelimitutils import FederationRateLimiter @@ -232,7 +234,7 @@ class HomeserverTestCase(TestCase): # Honour the `use_frozen_dicts` config option. We have to do this # manually because this is taken care of in the app `start` code, which # we don't run. Plus we want to reset it on tearDown. - events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts + events.USE_FROZEN_DICTS = self.hs.config.server.use_frozen_dicts if self.hs is None: raise Exception("No homeserver returned from make_homeserver.") @@ -315,6 +317,15 @@ class HomeserverTestCase(TestCase): self.reactor.advance(0.01) time.sleep(0.01) + def wait_for_background_updates(self) -> None: + """Block until all background database updates have completed.""" + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + def make_homeserver(self, reactor, clock): """ Make and return a homeserver. @@ -371,7 +382,7 @@ class HomeserverTestCase(TestCase): return config - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): """ Prepare for the test. This involves things like mocking out parts of the homeserver, or building test data common across the whole test @@ -447,7 +458,7 @@ class HomeserverTestCase(TestCase): client_ip, ) - def setup_test_homeserver(self, *args, **kwargs): + def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer: """ Set up the test homeserver, meant to be called by the overridable make_homeserver. It automatically passes through the test class's @@ -558,7 +569,7 @@ class HomeserverTestCase(TestCase): Returns: The MXID of the new user. """ - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" # Create the user channel = self.make_request("GET", "/_synapse/admin/v1/register") @@ -594,6 +605,35 @@ class HomeserverTestCase(TestCase): user_id = channel.json_body["user_id"] return user_id + def register_appservice_user( + self, + username: str, + appservice_token: str, + ) -> str: + """Register an appservice user as an application service. + Requires the client-facing registration API be registered. + + Args: + username: the user to be registered by an application service. + Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart" + appservice_token: the acccess token for that application service. + + Raises: if the request to '/register' does not return 200 OK. + + Returns: the MXID of the new user. + """ + channel = self.make_request( + "POST", + "/_matrix/client/r0/register", + { + "username": username, + "type": "m.login.application_service", + }, + access_token=appservice_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + return channel.json_body["user_id"] + def login( self, username,