diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index cccff7af26..3aa9ba3c43 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -217,7 +217,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
- location=self.hs.config.server_name,
+ location=self.hs.config.server.server_name,
identifier="key",
key=self.hs.config.key.macaroon_secret_key,
)
@@ -239,7 +239,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
- location=self.hs.config.server_name,
+ location=self.hs.config.server.server_name,
identifier="key",
key=self.hs.config.key.macaroon_secret_key,
)
@@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
- self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.value.code, 403)
@@ -303,7 +303,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
"abcd",
- self.hs.config.server_name,
+ self.hs.config.server.server_name,
id="1234",
namespaces={
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
@@ -332,7 +332,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
"abcd",
- self.hs.config.server_name,
+ self.hs.config.server.server_name,
id="1234",
namespaces={
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
@@ -372,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
- self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.value.code, 403)
@@ -387,7 +387,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
- self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.value.code, 403)
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index a2b5ed2030..55f0899bae 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -24,7 +24,7 @@ from synapse.appservice.scheduler import (
from synapse.logging.context import make_deferred_yieldable
from tests import unittest
-from tests.test_utils import make_awaitable
+from tests.test_utils import simple_async_mock
from ..utils import MockClock
@@ -49,11 +49,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
txn = Mock(id=txn_id, service=service, events=events)
# mock methods
- self.store.get_appservice_state = Mock(
- return_value=defer.succeed(ApplicationServiceState.UP)
- )
- txn.send = Mock(return_value=make_awaitable(True))
- self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
+ self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
+ txn.send = simple_async_mock(True)
+ txn.complete = simple_async_mock(True)
+ self.store.create_appservice_txn = simple_async_mock(txn)
# actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -71,10 +70,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events = [Mock(), Mock()]
txn = Mock(id="idhere", service=service, events=events)
- self.store.get_appservice_state = Mock(
- return_value=defer.succeed(ApplicationServiceState.DOWN)
+ self.store.get_appservice_state = simple_async_mock(
+ ApplicationServiceState.DOWN
)
- self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
+ self.store.create_appservice_txn = simple_async_mock(txn)
# actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -94,12 +93,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
txn = Mock(id=txn_id, service=service, events=events)
# mock methods
- self.store.get_appservice_state = Mock(
- return_value=defer.succeed(ApplicationServiceState.UP)
- )
- self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
- txn.send = Mock(return_value=make_awaitable(False)) # fails to send
- self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
+ self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
+ self.store.set_appservice_state = simple_async_mock(True)
+ txn.send = simple_async_mock(False) # fails to send
+ self.store.create_appservice_txn = simple_async_mock(txn)
# actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -122,7 +119,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.as_api = Mock()
self.store = Mock()
self.service = Mock()
- self.callback = Mock()
+ self.callback = simple_async_mock()
self.recoverer = _Recoverer(
clock=self.clock,
as_api=self.as_api,
@@ -144,8 +141,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
# shouldn't have called anything prior to waiting for exp backoff
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
- txn.send = Mock(return_value=make_awaitable(True))
- txn.complete.return_value = make_awaitable(None)
+ txn.send = simple_async_mock(True)
+ txn.complete = simple_async_mock(None)
# wait for exp backoff
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
@@ -170,8 +167,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
- txn.send = Mock(return_value=make_awaitable(False))
- txn.complete.return_value = make_awaitable(None)
+ txn.send = simple_async_mock(False)
+ txn.complete = simple_async_mock(None)
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
@@ -184,7 +181,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEquals(3, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
self.assertEquals(0, self.callback.call_count)
- txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn
+ txn.send = simple_async_mock(True) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16)
self.assertEquals(1, txn.send.call_count) # new mock reset call count
@@ -195,6 +192,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self):
self.txn_ctrl = Mock()
+ self.txn_ctrl.send = simple_async_mock()
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
def test_send_single_event_no_queue(self):
diff --git a/tests/config/test_base.py b/tests/config/test_base.py
index baa5313fb3..6a52f862f4 100644
--- a/tests/config/test_base.py
+++ b/tests/config/test_base.py
@@ -14,23 +14,28 @@
import os.path
import tempfile
+from unittest.mock import Mock
from synapse.config import ConfigError
+from synapse.config._base import Config
from synapse.util.stringutils import random_string
from tests import unittest
-class BaseConfigTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
- self.hs = hs
+class BaseConfigTestCase(unittest.TestCase):
+ def setUp(self):
+ # The root object needs a server property with a public_baseurl.
+ root = Mock()
+ root.server.public_baseurl = "http://test"
+ self.config = Config(root)
def test_loading_missing_templates(self):
# Use a temporary directory that exists on the system, but that isn't likely to
# contain template files
with tempfile.TemporaryDirectory() as tmp_dir:
# Attempt to load an HTML template from our custom template directory
- template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0]
+ template = self.config.read_templates(["sso_error.html"], (tmp_dir,))[0]
# If no errors, we should've gotten the default template instead
@@ -60,7 +65,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
# Attempt to load the template from our custom template directory
template = (
- self.hs.config.read_templates([template_filename], (tmp_dir,))
+ self.config.read_templates([template_filename], (tmp_dir,))
)[0]
# Render the template
@@ -97,7 +102,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
# Retrieve the template.
template = (
- self.hs.config.read_templates(
+ self.config.read_templates(
[template_filename],
(td.name for td in tempdirs),
)
@@ -118,7 +123,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
# Retrieve the template.
template = (
- self.hs.config.read_templates(
+ self.config.read_templates(
[other_template_name],
(td.name for td in tempdirs),
)
@@ -134,6 +139,6 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
def test_loading_template_from_nonexistent_custom_directory(self):
with self.assertRaises(ConfigError):
- self.hs.config.read_templates(
+ self.config.read_templates(
["some_filename.html"], ("a_nonexistent_directory",)
)
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 857d9cd096..79d417568d 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -12,59 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.config._base import Config, RootConfig
+from unittest.mock import patch
+
from synapse.config.cache import CacheConfig, add_resizable_cache
from synapse.util.caches.lrucache import LruCache
from tests.unittest import TestCase
-class FakeServer(Config):
- section = "server"
-
-
-class TestConfig(RootConfig):
- config_classes = [FakeServer, CacheConfig]
-
-
+# Patch the global _CACHES so that each test runs against its own state.
+@patch("synapse.config.cache._CACHES", new_callable=dict)
class CacheConfigTests(TestCase):
def setUp(self):
# Reset caches before each test
- TestConfig().caches.reset()
+ self.config = CacheConfig()
+
+ def tearDown(self):
+ self.config.reset()
- def test_individual_caches_from_environ(self):
+ def test_individual_caches_from_environ(self, _caches):
"""
Individual cache factors will be loaded from the environment.
"""
config = {}
- t = TestConfig()
- t.caches._environ = {
+ self.config._environ = {
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
"SYNAPSE_NOT_CACHE": "BLAH",
}
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0})
+ self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
- def test_config_overrides_environ(self):
+ def test_config_overrides_environ(self, _caches):
"""
Individual cache factors defined in the environment will take precedence
over those in the config.
"""
config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
- t = TestConfig()
- t.caches._environ = {
+ self.config._environ = {
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
"SYNAPSE_CACHE_FACTOR_FOO": 1,
}
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(
- dict(t.caches.cache_factors),
+ dict(self.config.cache_factors),
{"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
)
- def test_individual_instantiated_before_config_load(self):
+ def test_individual_instantiated_before_config_load(self, _caches):
"""
If a cache is instantiated before the config is read, it will be given
the default cache size in the interim, and then resized once the config
@@ -76,26 +72,24 @@ class CacheConfigTests(TestCase):
self.assertEqual(cache.max_size, 50)
config = {"caches": {"per_cache_factors": {"foo": 3}}}
- t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config)
self.assertEqual(cache.max_size, 300)
- def test_individual_instantiated_after_config_load(self):
+ def test_individual_instantiated_after_config_load(self, _caches):
"""
If a cache is instantiated after the config is read, it will be
immediately resized to the correct size given the per_cache_factor if
there is one.
"""
config = {"caches": {"per_cache_factors": {"foo": 2}}}
- t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
cache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 200)
- def test_global_instantiated_before_config_load(self):
+ def test_global_instantiated_before_config_load(self, _caches):
"""
If a cache is instantiated before the config is read, it will be given
the default cache size in the interim, and then resized to the new
@@ -106,26 +100,24 @@ class CacheConfigTests(TestCase):
self.assertEqual(cache.max_size, 50)
config = {"caches": {"global_factor": 4}}
- t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(cache.max_size, 400)
- def test_global_instantiated_after_config_load(self):
+ def test_global_instantiated_after_config_load(self, _caches):
"""
If a cache is instantiated after the config is read, it will be
immediately resized to the correct size given the global factor if there
is no per-cache factor.
"""
config = {"caches": {"global_factor": 1.5}}
- t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
cache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 150)
- def test_cache_with_asterisk_in_name(self):
+ def test_cache_with_asterisk_in_name(self, _caches):
"""Some caches have asterisks in their name, test that they are set correctly."""
config = {
@@ -133,12 +125,11 @@ class CacheConfigTests(TestCase):
"per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
}
}
- t = TestConfig()
- t.caches._environ = {
+ self.config._environ = {
"SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
"SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
}
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
cache_a = LruCache(100)
add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
@@ -152,17 +143,16 @@ class CacheConfigTests(TestCase):
add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
self.assertEqual(cache_c.max_size, 200)
- def test_apply_cache_factor_from_config(self):
+ def test_apply_cache_factor_from_config(self, _caches):
"""Caches can disable applying cache factor updates, mainly used by
event cache size.
"""
config = {"caches": {"event_cache_size": "10k"}}
- t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
cache = LruCache(
- max_size=t.caches.event_cache_size,
+ max_size=self.config.event_cache_size,
apply_cache_factor_from_config=False,
)
add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index ef6c2beec7..59635de205 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -49,7 +49,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
config = HomeServerConfig.load_config("", ["-c", self.file])
self.assertTrue(
- hasattr(config, "macaroon_secret_key"),
+ hasattr(config.key, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key",
)
if len(config.key.macaroon_secret_key) < 5:
@@ -60,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
self.assertTrue(
- hasattr(config, "macaroon_secret_key"),
+ hasattr(config.key, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key",
)
if len(config.key.macaroon_secret_key) < 5:
@@ -74,8 +74,12 @@ class ConfigLoadingTestCase(unittest.TestCase):
config1 = HomeServerConfig.load_config("", ["-c", self.file])
config2 = HomeServerConfig.load_config("", ["-c", self.file])
config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
- self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key)
- self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key)
+ self.assertEqual(
+ config1.key.macaroon_secret_key, config2.key.macaroon_secret_key
+ )
+ self.assertEqual(
+ config1.key.macaroon_secret_key, config3.key.macaroon_secret_key
+ )
def test_disable_registration(self):
self.generate_config()
@@ -84,16 +88,16 @@ class ConfigLoadingTestCase(unittest.TestCase):
)
# Check that disable_registration clobbers enable_registration.
config = HomeServerConfig.load_config("", ["-c", self.file])
- self.assertFalse(config.enable_registration)
+ self.assertFalse(config.registration.enable_registration)
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
- self.assertFalse(config.enable_registration)
+ self.assertFalse(config.registration.enable_registration)
# Check that either config value is clobbered by the command line.
config = HomeServerConfig.load_or_generate_config(
"", ["-c", self.file, "--enable-registration"]
)
- self.assertTrue(config.enable_registration)
+ self.assertTrue(config.registration.enable_registration)
def test_stats_enabled(self):
self.generate_config_and_remove_lines_containing("enable_metrics")
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index b6bc1876b5..9ba5781573 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -42,9 +42,9 @@ class TLSConfigTests(TestCase):
"""
config = {}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
def test_tls_client_minimum_set(self):
"""
@@ -52,29 +52,29 @@ class TLSConfigTests(TestCase):
"""
config = {"federation_client_minimum_tls_version": 1}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
config = {"federation_client_minimum_tls_version": 1.1}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1.1")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.1")
config = {"federation_client_minimum_tls_version": 1.2}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1.2")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2")
# Also test a string version
config = {"federation_client_minimum_tls_version": "1"}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
config = {"federation_client_minimum_tls_version": "1.2"}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1.2")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2")
def test_tls_client_minimum_1_point_3_missing(self):
"""
@@ -91,7 +91,7 @@ class TLSConfigTests(TestCase):
config = {"federation_client_minimum_tls_version": 1.3}
t = TestConfig()
with self.assertRaises(ConfigError) as e:
- t.read_config(config, config_dir_path="", data_dir_path="")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(
e.exception.args[0],
(
@@ -112,8 +112,8 @@ class TLSConfigTests(TestCase):
config = {"federation_client_minimum_tls_version": 1.3}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1.3")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3")
def test_tls_client_minimum_set_passed_through_1_2(self):
"""
@@ -121,7 +121,7 @@ class TLSConfigTests(TestCase):
"""
config = {"federation_client_minimum_tls_version": 1.2}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
cf = FederationPolicyForHTTPS(t)
options = _get_ssl_context_options(cf._verify_ssl_context)
@@ -137,7 +137,7 @@ class TLSConfigTests(TestCase):
"""
config = {"federation_client_minimum_tls_version": 1}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
cf = FederationPolicyForHTTPS(t)
options = _get_ssl_context_options(cf._verify_ssl_context)
@@ -159,7 +159,7 @@ class TLSConfigTests(TestCase):
}
t = TestConfig()
e = self.assertRaises(
- ConfigError, t.read_config, config, config_dir_path="", data_dir_path=""
+ ConfigError, t.tls.read_config, config, config_dir_path="", data_dir_path=""
)
self.assertIn("IDNA domain names", str(e))
@@ -174,7 +174,7 @@ class TLSConfigTests(TestCase):
]
}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
cf = FederationPolicyForHTTPS(t)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 3b3866bff8..3deb14c308 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -26,6 +26,7 @@ from synapse.rest.client import login, presence, room
from synapse.types import JsonDict, StreamToken, create_requester
from tests.handlers.test_sync import generate_sync_config
+from tests.test_utils import simple_async_mock
from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
@@ -133,8 +134,12 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
+ # Mock out the calls over federation.
+ fed_transport_client = Mock(spec=["send_transaction"])
+ fed_transport_client.send_transaction = simple_async_mock({})
+
hs = self.setup_test_homeserver(
- federation_transport_client=Mock(spec=["send_transaction"]),
+ federation_transport_client=fed_transport_client,
)
# Load the modules into the homeserver
module_api = hs.get_module_api()
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 65b18fbd7a..b457dad6d2 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -336,7 +336,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
recovery
"""
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+ mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
u1 = self.register_user("user", "pass")
@@ -376,7 +376,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
This case tests the behaviour when the server has never been reachable.
"""
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+ mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
u1 = self.register_user("user", "pass")
@@ -429,7 +429,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# now the server goes offline
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+ mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 0b60cc4261..03e1e11f49 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -120,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(
channel.json_body["room_version"],
- self.hs.config.default_room_version.identifier,
+ self.hs.config.server.default_room_version.identifier,
)
members = set(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 57cc3e2646..c153018fd8 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -110,7 +110,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
def test_set_my_name_if_disabled(self):
- self.hs.config.enable_set_displayname = False
+ self.hs.config.registration.enable_set_displayname = False
# Setting displayname for the first time is allowed
self.get_success(
@@ -225,7 +225,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
def test_set_my_avatar_if_disabled(self):
- self.hs.config.enable_set_avatar_url = False
+ self.hs.config.registration.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
self.get_success(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index d3efb67e3e..db691c4c1c 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -16,7 +16,12 @@ from unittest.mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, ResourceLimitError, SynapseError
+from synapse.api.errors import (
+ CodeMessageException,
+ Codes,
+ ResourceLimitError,
+ SynapseError,
+)
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -120,14 +125,24 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
hs_config = self.default_config()
# some of the tests rely on us having a user consent version
- hs_config["user_consent"] = {
- "version": "test_consent_version",
- "template_dir": ".",
- }
+ hs_config.setdefault("user_consent", {}).update(
+ {
+ "version": "test_consent_version",
+ "template_dir": ".",
+ }
+ )
hs_config["max_mau_value"] = 50
hs_config["limit_usage_by_mau"] = True
- hs = self.setup_test_homeserver(config=hs_config)
+ # Don't attempt to reach out over federation.
+ self.mock_federation_client = Mock()
+ self.mock_federation_client.make_query.side_effect = CodeMessageException(
+ 500, ""
+ )
+
+ hs = self.setup_test_homeserver(
+ config=hs_config, federation_client=self.mock_federation_client
+ )
load_legacy_spam_checkers(hs)
@@ -138,9 +153,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor, clock, hs):
- self.mock_distributor = Mock()
- self.mock_distributor.declare("registered_user")
- self.mock_captcha_client = Mock()
self.handler = self.hs.get_registration_handler()
self.store = self.hs.get_datastore()
self.lots_of_users = 100
@@ -174,21 +186,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None)
+ @override_config({"limit_usage_by_mau": False})
def test_mau_limits_when_disabled(self):
- self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
+ @override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self):
- self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value - 1)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
+ @override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self):
- self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
@@ -198,15 +210,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)
+ @override_config({"limit_usage_by_mau": True})
def test_register_mau_blocked(self):
- self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
@@ -215,16 +227,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
+ @override_config(
+ {"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False}
+ )
def test_auto_join_rooms_for_guests(self):
- room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
- self.hs.config.auto_join_rooms_for_guests = False
user_id = self.get_success(
self.handler.register_user(localpart="jeff", make_guest=True),
)
@@ -243,34 +255,33 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(room_id["room_id"] in rooms)
self.assertEqual(len(rooms), 1)
+ @override_config({"auto_join_rooms": []})
def test_auto_create_auto_join_rooms_with_no_rooms(self):
- self.hs.config.auto_join_rooms = []
frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string())
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ @override_config({"auto_join_rooms": ["#room:another"]})
def test_auto_create_auto_join_where_room_is_another_domain(self):
- self.hs.config.auto_join_rooms = ["#room:another"]
frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string())
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ @override_config(
+ {"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False}
+ )
def test_auto_create_auto_join_where_auto_create_is_false(self):
- self.hs.config.autocreate_auto_join_rooms = False
- room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self):
room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
-
self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -294,10 +305,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(room_id["room_id"] in rooms)
self.assertEqual(len(rooms), 1)
+ @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self):
- room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
-
self.store.count_real_users = Mock(return_value=make_awaitable(2))
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
@@ -510,6 +519,17 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(rooms, set())
self.assertEqual(invited_rooms, [])
+ @override_config(
+ {
+ "user_consent": {
+ "block_events_error": "Error",
+ "require_at_registration": True,
+ },
+ "form_secret": "53cr3t",
+ "public_baseurl": "http://test",
+ "auto_join_rooms": ["#room:test"],
+ },
+ )
def test_auto_create_auto_join_where_no_consent(self):
"""Test to ensure that the first user is not auto-joined to a room if
they have not given general consent.
@@ -521,25 +541,20 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# * The server is configured to auto-join to a room
# (and autocreate if necessary)
- event_creation_handler = self.hs.get_event_creation_handler()
- # (Messing with the internals of event_creation_handler is fragile
- # but can't see a better way to do this. One option could be to subclass
- # the test with custom config.)
- event_creation_handler._block_events_without_consent_error = "Error"
- event_creation_handler._consent_uri_builder = Mock()
- room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
-
# When:-
- # * the user is registered and post consent actions are called
+ # * the user is registered
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
- self.get_success(self.handler.post_consent_actions(user_id))
# Then:-
# * Ensure that they have not been joined to the room
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ # The user provides consent; ensure they are now in the rooms.
+ self.get_success(self.handler.post_consent_actions(user_id))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertEqual(len(rooms), 1)
+
def test_register_support_user(self):
user_id = self.get_success(
self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 24b7ef6efc..56207f4db6 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -103,12 +103,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Do the initial population of the stats via the background update
self._add_background_updates()
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
def test_initial_room(self):
"""
@@ -140,12 +135,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
r = self.get_success(self.get_all_room_state())
@@ -568,12 +558,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
r1stats_complete = self._get_current_stats("room", r1)
u1stats_complete = self._get_current_stats("user", u1)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 266333c553..db65253773 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -11,47 +11,208 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Tuple
-from unittest.mock import Mock
+from typing import Tuple
+from unittest.mock import Mock, patch
from urllib.parse import quote
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.room_versions import RoomVersion, RoomVersions
-from synapse.rest.client import login, room, user_directory
+from synapse.appservice import ApplicationService
+from synapse.rest.client import login, register, room, user_directory
+from synapse.server import HomeServer
from synapse.storage.roommember import ProfileInfo
from synapse.types import create_requester
+from synapse.util import Clock
from tests import unittest
+from tests.storage.test_user_directory import GetUserDirectoryTables
+from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
class UserDirectoryTestCase(unittest.HomeserverTestCase):
- """
- Tests the UserDirectoryHandler.
+ """Tests the UserDirectoryHandler.
+
+ We're broadly testing two kinds of things here.
+
+ 1. Check that we correctly update the user directory in response
+ to events (e.g. join a room, leave a room, change name, make public)
+ 2. Check that the search logic behaves as expected.
+
+ The background process that rebuilds the user directory is tested in
+ tests/storage/test_user_directory.py.
"""
servlets = [
login.register_servlets,
synapse.rest.admin.register_servlets,
+ register.register_servlets,
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
-
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["update_user_directory"] = True
- return self.setup_test_homeserver(config=config)
- def prepare(self, reactor, clock, hs):
+ self.appservice = ApplicationService(
+ token="i_am_an_app_service",
+ hostname="test",
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ sender="@as:test",
+ )
+
+ mock_load_appservices = Mock(return_value=[self.appservice])
+ with patch(
+ "synapse.storage.databases.main.appservice.load_appservices",
+ mock_load_appservices,
+ ):
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.handler = hs.get_user_directory_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.event_creation_handler = self.hs.get_event_creation_handler()
+ self.user_dir_helper = GetUserDirectoryTables(self.store)
+
+ def test_normal_user_pair(self) -> None:
+ """Sanity check that the room-sharing tables are updated correctly."""
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+
+ public = self.helper.create_room_as(
+ alice,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=alice_token,
+ )
+ private = self.helper.create_room_as(alice, is_public=False, tok=alice_token)
+ self.helper.invite(private, alice, bob, tok=alice_token)
+ self.helper.join(public, bob, tok=bob_token)
+ self.helper.join(private, bob, tok=bob_token)
+
+ # Alice also makes a second public room but no-one else joins
+ public2 = self.helper.create_room_as(
+ alice,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=alice_token,
+ )
+
+ users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
+ in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ in_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(
+ set(in_public), {(alice, public), (bob, public), (alice, public2)}
+ )
+ self.assertEqual(
+ self.user_dir_helper._compress_shared(in_private),
+ {(alice, bob, private), (bob, alice, private)},
+ )
+
+ # The next three tests (test_population_excludes_*) all setup
+ # - A normal user included in the user dir
+ # - A public and private room created by that user
+ # - A user excluded from the room dir, belonging to both rooms
+
+ # They match similar logic in storage/test_user_directory. But that tests
+ # rebuilding the directory; this tests updating it incrementally.
+
+ def test_excludes_support_user(self) -> None:
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ support = "@support1:test"
+ self.get_success(
+ self.store.register_user(
+ user_id=support, password_hash=None, user_type=UserTypes.SUPPORT
+ )
+ )
+
+ public, private = self._create_rooms_and_inject_memberships(
+ alice, alice_token, support
+ )
+ self._check_only_one_user_in_directory(alice, public)
+
+ def test_excludes_deactivated_user(self) -> None:
+ admin = self.register_user("admin", "pass", admin=True)
+ admin_token = self.login(admin, "pass")
+ user = self.register_user("naughty", "pass")
+
+ # Deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ f"/_synapse/admin/v2/users/{user}",
+ access_token=admin_token,
+ content={"deactivated": True},
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["deactivated"], True)
+
+ # Join the deactivated user to rooms owned by the admin.
+ # Is this something that could actually happen outside of a test?
+ public, private = self._create_rooms_and_inject_memberships(
+ admin, admin_token, user
+ )
+ self._check_only_one_user_in_directory(admin, public)
+
+ def test_excludes_appservices_user(self) -> None:
+ # Register an AS user.
+ user = self.register_user("user", "pass")
+ token = self.login(user, "pass")
+ as_user = self.register_appservice_user("as_user_potato", self.appservice.token)
+
+ # Join the AS user to rooms owned by the normal user.
+ public, private = self._create_rooms_and_inject_memberships(
+ user, token, as_user
+ )
+ self._check_only_one_user_in_directory(user, public)
+
+ def _create_rooms_and_inject_memberships(
+ self, creator: str, token: str, joiner: str
+ ) -> Tuple[str, str]:
+ """Create a public and private room as a normal user.
+ Then get the `joiner` into those rooms.
+ """
+ # TODO: Duplicates the same-named method in UserDirectoryInitialPopulationTest.
+ public_room = self.helper.create_room_as(
+ creator,
+ is_public=True,
+ # See https://github.com/matrix-org/synapse/issues/10951
+ extra_content={"visibility": "public"},
+ tok=token,
+ )
+ private_room = self.helper.create_room_as(creator, is_public=False, tok=token)
+
+ # HACK: get the user into these rooms
+ self.get_success(inject_member_event(self.hs, public_room, joiner, "join"))
+ self.get_success(inject_member_event(self.hs, private_room, joiner, "join"))
+
+ return public_room, private_room
+
+ def _check_only_one_user_in_directory(self, user: str, public: str) -> None:
+ users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
+ in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ in_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
- def test_handle_local_profile_change_with_support_user(self):
+ self.assertEqual(users, {user})
+ self.assertEqual(set(in_public), {(user, public)})
+ self.assertEqual(in_private, [])
+
+ def test_handle_local_profile_change_with_support_user(self) -> None:
support_user_id = "@support:test"
self.get_success(
self.store.register_user(
@@ -64,7 +225,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.handler.handle_local_profile_change(support_user_id, None)
+ self.handler.handle_local_profile_change(
+ support_user_id, ProfileInfo("I love support me", None)
+ )
)
profile = self.get_success(self.store.get_user_in_directory(support_user_id))
self.assertTrue(profile is None)
@@ -77,7 +240,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
self.assertTrue(profile["display_name"] == display_name)
- def test_handle_local_profile_change_with_deactivated_user(self):
+ def test_handle_local_profile_change_with_deactivated_user(self) -> None:
# create user
r_user_id = "@regular:test"
self.get_success(
@@ -112,7 +275,27 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
profile = self.get_success(self.store.get_user_in_directory(r_user_id))
self.assertTrue(profile is None)
- def test_handle_user_deactivated_support_user(self):
+ def test_handle_local_profile_change_with_appservice_user(self) -> None:
+ # create user
+ as_user_id = self.register_appservice_user(
+ "as_user_alice", self.appservice.token
+ )
+
+ # profile is not in directory
+ profile = self.get_success(self.store.get_user_in_directory(as_user_id))
+ self.assertTrue(profile is None)
+
+ # update profile
+ profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3")
+ self.get_success(
+ self.handler.handle_local_profile_change(as_user_id, profile_info)
+ )
+
+ # profile is still not in directory
+ profile = self.get_success(self.store.get_user_in_directory(as_user_id))
+ self.assertTrue(profile is None)
+
+ def test_handle_user_deactivated_support_user(self) -> None:
s_user_id = "@support:test"
self.get_success(
self.store.register_user(
@@ -120,20 +303,29 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
- self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
- self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
- self.store.remove_from_user_dir.not_called()
+ mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+ with patch.object(
+ self.store, "remove_from_user_dir", mock_remove_from_user_dir
+ ):
+ self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
+ # BUG: the correct spelling is assert_not_called, but that makes the test fail
+ # and it's not clear that this is actually the behaviour we want.
+ mock_remove_from_user_dir.not_called()
- def test_handle_user_deactivated_regular_user(self):
+ def test_handle_user_deactivated_regular_user(self) -> None:
r_user_id = "@regular:test"
self.get_success(
self.store.register_user(user_id=r_user_id, password_hash=None)
)
- self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
- self.get_success(self.handler.handle_local_user_deactivated(r_user_id))
- self.store.remove_from_user_dir.called_once_with(r_user_id)
- def test_reactivation_makes_regular_user_searchable(self):
+ mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+ with patch.object(
+ self.store, "remove_from_user_dir", mock_remove_from_user_dir
+ ):
+ self.get_success(self.handler.handle_local_user_deactivated(r_user_id))
+ mock_remove_from_user_dir.assert_called_once_with(r_user_id)
+
+ def test_reactivation_makes_regular_user_searchable(self) -> None:
user = self.register_user("regular", "pass")
user_token = self.login(user, "pass")
admin_user = self.register_user("admin", "pass", admin=True)
@@ -171,7 +363,147 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
self.assertEqual(s["results"][0]["user_id"], user)
- def test_private_room(self):
+ def test_process_join_after_server_leaves_room(self) -> None:
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+
+ # Alice makes two rooms. Bob joins one of them.
+ room1 = self.helper.create_room_as(alice, tok=alice_token)
+ room2 = self.helper.create_room_as(alice, tok=alice_token)
+ self.helper.join(room1, bob, tok=bob_token)
+
+ # The user sharing tables should have been updated.
+ public1 = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertEqual(set(public1), {(alice, room1), (alice, room2), (bob, room1)})
+
+ # Alice leaves room1. The user sharing tables should be updated.
+ self.helper.leave(room1, alice, tok=alice_token)
+ public2 = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertEqual(set(public2), {(alice, room2), (bob, room1)})
+
+ # Pause the processing of new events.
+ dir_handler = self.hs.get_user_directory_handler()
+ dir_handler.update_user_directory = False
+
+ # Bob leaves one room and joins the other.
+ self.helper.leave(room1, bob, tok=bob_token)
+ self.helper.join(room2, bob, tok=bob_token)
+
+ # Process the leave and join in one go.
+ dir_handler.update_user_directory = True
+ dir_handler.notify_new_event()
+ self.wait_for_background_updates()
+
+ # The user sharing tables should have been updated.
+ public3 = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertEqual(set(public3), {(alice, room2), (bob, room2)})
+
+ def test_per_room_profile_doesnt_alter_directory_entry(self) -> None:
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+
+ # Alice should have a user directory entry created at registration.
+ users = self.get_success(self.user_dir_helper.get_profiles_in_user_directory())
+ self.assertEqual(
+ users[alice], ProfileInfo(display_name="alice", avatar_url=None)
+ )
+
+ # Alice makes a room for herself.
+ room = self.helper.create_room_as(alice, is_public=True, tok=alice_token)
+
+ # Alice sets a nickname unique to that room.
+ self.helper.send_state(
+ room,
+ "m.room.member",
+ {
+ "displayname": "Freddy Mercury",
+ "membership": "join",
+ },
+ alice_token,
+ state_key=alice,
+ )
+
+ # Alice's display name remains the same in the user directory.
+ search_result = self.get_success(self.handler.search_users(bob, alice, 10))
+ self.assertEqual(
+ search_result["results"],
+ [{"display_name": "alice", "avatar_url": None, "user_id": alice}],
+ 0,
+ )
+
+ def test_making_room_public_doesnt_alter_directory_entry(self) -> None:
+ """Per-room names shouldn't go to the directory when the room becomes public.
+
+ This isn't about preventing a leak (the room is now public, so the nickname
+ is too). It's about preserving the invariant that we only show a user's public
+ profile in the user directory results.
+
+ I made this a Synapse test case rather than a Complement one because
+ I think this is (strictly speaking) an implementation choice. Synapse
+ has chosen to only ever use the public profile when responding to a user
+ directory search. There's no privacy leak here, because making the room
+ public discloses the per-room name.
+
+ The spec doesn't mandate anything about _how_ a user
+ should appear in a /user_directory/search result. Hypothetical example:
+ suppose Bob searches for Alice. When representing Alice in a search
+ result, it's reasonable to use any of Alice's nicknames that Bob is
+ aware of. Heck, maybe we even want to use lots of them in a combined
+ displayname like `Alice (aka "ali", "ally", "41iC3")`.
+ """
+
+ # TODO the same should apply when Alice is a remote user.
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+
+ # Alice and Bob are in a private room.
+ room = self.helper.create_room_as(alice, is_public=False, tok=alice_token)
+ self.helper.invite(room, src=alice, targ=bob, tok=alice_token)
+ self.helper.join(room, user=bob, tok=bob_token)
+
+ # Alice has a nickname unique to that room.
+
+ self.helper.send_state(
+ room,
+ "m.room.member",
+ {
+ "displayname": "Freddy Mercury",
+ "membership": "join",
+ },
+ alice_token,
+ state_key=alice,
+ )
+
+ # Check Alice isn't recorded as being in a public room.
+ public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertNotIn((alice, room), public)
+
+ # One of them makes the room public.
+ self.helper.send_state(
+ room,
+ "m.room.join_rules",
+ {"join_rule": "public"},
+ alice_token,
+ )
+
+ # Check that Alice is now recorded as being in a public room
+ public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertIn((alice, room), public)
+
+ # Alice's display name remains the same in the user directory.
+ search_result = self.get_success(self.handler.search_users(bob, alice, 10))
+ self.assertEqual(
+ search_result["results"],
+ [{"display_name": "alice", "avatar_url": None, "user_id": alice}],
+ 0,
+ )
+
+ def test_private_room(self) -> None:
"""
A user can be searched for only by people that are either in a public
room, or that share a private chat.
@@ -191,11 +523,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Check we have populated the database correctly.
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
self.assertEqual(
- self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+ self.user_dir_helper._compress_shared(shares_private),
+ {(u1, u2, room), (u2, u1, room)},
)
self.assertEqual(public_users, [])
@@ -215,10 +552,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.leave(room, user=u2, tok=u2_token)
# Check we have removed the values.
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
- self.assertEqual(self._compress_shared(shares_private), set())
+ self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
self.assertEqual(public_users, [])
# User1 now gets no search results for any of the other users.
@@ -228,7 +569,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user3", 10))
self.assertEqual(len(s["results"]), 0)
- def test_spam_checker(self):
+ def test_spam_checker(self) -> None:
"""
A user which fails the spam checks will not appear in search results.
"""
@@ -246,11 +587,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Check we have populated the database correctly.
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
self.assertEqual(
- self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+ self.user_dir_helper._compress_shared(shares_private),
+ {(u1, u2, room), (u2, u1, room)},
)
self.assertEqual(public_users, [])
@@ -258,7 +604,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
- async def allow_all(user_profile):
+ async def allow_all(user_profile: ProfileInfo) -> bool:
# Allow all users.
return False
@@ -272,7 +618,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- async def block_all(user_profile):
+ async def block_all(user_profile: ProfileInfo) -> bool:
# All users are spammy.
return True
@@ -282,7 +628,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
- def test_legacy_spam_checker(self):
+ def test_legacy_spam_checker(self) -> None:
"""
A spam checker without the expected method should be ignored.
"""
@@ -300,11 +646,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Check we have populated the database correctly.
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
self.assertEqual(
- self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+ self.user_dir_helper._compress_shared(shares_private),
+ {(u1, u2, room), (u2, u1, room)},
)
self.assertEqual(public_users, [])
@@ -317,134 +668,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
- def _compress_shared(self, shared):
- """
- Compress a list of users who share rooms dicts to a list of tuples.
- """
- r = set()
- for i in shared:
- r.add((i["user_id"], i["other_user_id"], i["room_id"]))
- return r
-
- def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
- r = self.get_success(
- self.store.db_pool.simple_select_list(
- "users_in_public_rooms", None, ("user_id", "room_id")
- )
- )
- retval = []
- for i in r:
- retval.append((i["user_id"], i["room_id"]))
- return retval
-
- def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]:
- return self.get_success(
- self.store.db_pool.simple_select_list(
- "users_who_share_private_rooms",
- None,
- ["user_id", "other_user_id", "room_id"],
- )
- )
-
- def _add_background_updates(self):
- """
- Add the background updates we need to run.
- """
- # Ugh, have to reset this flag
- self.store.db_pool.updates._all_done = False
-
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": "populate_user_directory_createtables",
- "progress_json": "{}",
- },
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": "populate_user_directory_process_rooms",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_createtables",
- },
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": "populate_user_directory_process_users",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_process_rooms",
- },
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": "populate_user_directory_cleanup",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_process_users",
- },
- )
- )
-
- def test_initial(self):
- """
- The user directory's initial handler correctly updates the search tables.
- """
- u1 = self.register_user("user1", "pass")
- u1_token = self.login(u1, "pass")
- u2 = self.register_user("user2", "pass")
- u2_token = self.login(u2, "pass")
- u3 = self.register_user("user3", "pass")
- u3_token = self.login(u3, "pass")
-
- room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
- self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
- self.helper.join(room, user=u2, tok=u2_token)
-
- private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
- self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token)
- self.helper.join(private_room, user=u3, tok=u3_token)
-
- self.get_success(self.store.update_user_directory_stream_pos(None))
- self.get_success(self.store.delete_all_from_user_dir())
-
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
-
- # Nothing updated yet
- self.assertEqual(shares_private, [])
- self.assertEqual(public_users, [])
-
- # Do the initial population of the user directory via the background update
- self._add_background_updates()
-
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
-
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
-
- # User 1 and User 2 are in the same public room
- self.assertEqual(set(public_users), {(u1, room), (u2, room)})
-
- # User 1 and User 3 share private rooms
- self.assertEqual(
- self._compress_shared(shares_private),
- {(u1, u3, private_room), (u3, u1, private_room)},
- )
-
- def test_initial_share_all_users(self):
+ def test_initial_share_all_users(self) -> None:
"""
Search all users = True means that a user does not have to share a
private room with the searching user or be in a public room to be search
@@ -457,26 +681,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.register_user("user2", "pass")
u3 = self.register_user("user3", "pass")
- # Wipe the user dir
- self.get_success(self.store.update_user_directory_stream_pos(None))
- self.get_success(self.store.delete_all_from_user_dir())
-
- # Do the initial population of the user directory via the background update
- self._add_background_updates()
-
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
-
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
# No users share rooms
self.assertEqual(public_users, [])
- self.assertEqual(self._compress_shared(shares_private), set())
+ self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
# Despite not sharing a room, search_all_users means we get a search
# result.
@@ -501,7 +715,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
}
}
)
- def test_prefer_local_users(self):
+ def test_prefer_local_users(self) -> None:
"""Tests that local users are shown higher in search results when
user_directory.prefer_local_users is True.
"""
@@ -535,15 +749,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
local_users = [local_user_1, local_user_2, local_user_3]
remote_users = [remote_user_1, remote_user_2, remote_user_3]
- # Populate the user directory via background update
- self._add_background_updates()
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
-
# The local searching user searches for the term "user", which other users have
# in their user id
results = self.get_success(
@@ -565,7 +770,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
room_id: str,
room_version: RoomVersion,
user_id: str,
- ):
+ ) -> None:
# Add a user to the room.
builder = self.event_builder_factory.for_room_version(
room_version,
@@ -588,8 +793,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
- user_id = "@test:test"
-
servlets = [
user_directory.register_servlets,
room.register_servlets,
@@ -597,7 +800,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["update_user_directory"] = True
hs = self.setup_test_homeserver(config=config)
@@ -606,19 +809,24 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
return hs
- def test_disabling_room_list(self):
+ def test_disabling_room_list(self) -> None:
self.config.userdirectory.user_directory_search_enabled = True
- # First we create a room with another user so that user dir is non-empty
- # for our user
- self.helper.create_room_as(self.user_id)
+ # Create two users and put them in the same room.
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
u2 = self.register_user("user2", "pass")
- room = self.helper.create_room_as(self.user_id)
- self.helper.join(room, user=u2)
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
- # Assert user directory is not empty
+ # Each should see the other when searching the user directory.
channel = self.make_request(
- "POST", b"user_directory/search", b'{"search_term":"user2"}'
+ "POST",
+ b"user_directory/search",
+ b'{"search_term":"user2"}',
+ access_token=u1_token,
)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) > 0)
@@ -626,7 +834,10 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
# Disable user directory and check search returns nothing
self.config.userdirectory.user_directory_search_enabled = False
channel = self.make_request(
- "POST", b"user_directory/search", b'{"search_term":"user2"}'
+ "POST",
+ b"user_directory/search",
+ b'{"search_term":"user2"}',
+ access_token=u1_token,
)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index d9a8b077d3..638babae69 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -226,7 +226,7 @@ class FederationClientTests(HomeserverTestCase):
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
# Set up the ip_range blacklist
- self.hs.config.federation_ip_range_blacklist = IPSet(
+ self.hs.config.server.federation_ip_range_blacklist = IPSet(
["127.0.0.0/8", "fe80::/64"]
)
self.reactor.lookups["internal"] = "127.0.0.1"
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index f73fcd684e..96f399b7ab 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -198,3 +198,31 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
self.assertEqual(log["url"], "/_matrix/client/versions")
self.assertEqual(log["protocol"], "1.1")
self.assertEqual(log["user_agent"], "")
+
+ def test_with_exception(self):
+ """
+ The logging exception type & value should be added to the JSON response.
+ """
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(JsonFormatter())
+ logger = self.get_logger(handler)
+
+ try:
+ raise ValueError("That's wrong, you wally!")
+ except ValueError:
+ logger.exception("Hello there, %s!", "wally")
+
+ log = self.get_log_line()
+
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ "exc_type",
+ "exc_value",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
+ self.assertEqual(log["exc_type"], "ValueError")
+ self.assertEqual(log["exc_value"], "That's wrong, you wally!")
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9d38974fba..e915dd5c7c 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -25,6 +25,7 @@ from synapse.types import create_requester
from tests.events.test_presence_router import send_presence_update, sync_presence
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.test_utils import simple_async_mock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -46,8 +47,12 @@ class ModuleApiTestCase(HomeserverTestCase):
self.auth_handler = homeserver.get_auth_handler()
def make_homeserver(self, reactor, clock):
+ # Mock out the calls over federation.
+ fed_transport_client = Mock(spec=["send_transaction"])
+ fed_transport_client.send_transaction = simple_async_mock({})
+
return self.setup_test_homeserver(
- federation_transport_client=Mock(spec=["send_transaction"]),
+ federation_transport_client=fed_transport_client,
)
def test_can_register_user(self):
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index c7555c26db..eac4664b41 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -70,8 +70,16 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# databases objects are the same.
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
+ # Normally we'd pass in the handler to `setup_test_homeserver`, which would
+ # eventually hit "Install @cache_in_self attributes" in tests/utils.py.
+ # Unfortunately our handler wants a reference to the homeserver. That leaves
+ # us with a chicken-and-egg problem.
+ # We can workaround this: create the homeserver first, create the handler
+ # and bodge it in after the fact. The bodging requires us to know the
+ # dirty details of how `cache_in_self` works. We politely ask mypy to
+ # ignore our dirty dealings.
self.test_handler = self._build_replication_data_handler()
- self.worker_hs._replication_data_handler = self.test_handler
+ self.worker_hs._replication_data_handler = self.test_handler # type: ignore[attr-defined]
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
@@ -240,7 +248,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
- b"localhost",
+ "localhost",
6379,
self.connect_any_redis_attempts,
)
@@ -315,12 +323,15 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
)
)
+ # Copy the port into a new, non-Optional variable so mypy knows we're
+ # not going to reset `instance_loc` to `None` under its feet. See
+ # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
+ port = instance_loc.port
+
self.reactor.add_tcp_client_callback(
self.reactor.lookups[instance_loc.host],
instance_loc.port,
- lambda: self._handle_http_replication_attempt(
- worker_hs, instance_loc.port
- ),
+ lambda: self._handle_http_replication_attempt(worker_hs, port),
)
store = worker_hs.get_datastore()
@@ -424,7 +435,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, b"localhost")
+ self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ee3ae9cce4..6ed9e42173 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -59,7 +59,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.hs = self.setup_test_homeserver()
- self.hs.config.registration_shared_secret = "shared"
+ self.hs.config.registration.registration_shared_secret = "shared"
self.hs.get_media_repository = Mock()
self.hs.get_deactivate_account_handler = Mock()
@@ -71,7 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
If there is no shared secret, registration through this method will be
prevented.
"""
- self.hs.config.registration_shared_secret = None
+ self.hs.config.registration.registration_shared_secret = None
channel = self.make_request("POST", self.url, b"{}")
@@ -422,7 +422,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -1485,7 +1485,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -1522,7 +1522,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 9e9e953cf4..89d85b0a17 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -470,13 +470,45 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
register.register_servlets,
]
+ def default_config(self):
+ config = super().default_config()
+ config["allow_guest_access"] = True
+ return config
+
def test_GET_whoami(self):
device_id = "wouldgohere"
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test", device_id=device_id)
- whoami = self.whoami(tok)
- self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id})
+ whoami = self._whoami(tok)
+ self.assertEqual(
+ whoami,
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ # Unstable until MSC3069 enters spec
+ "org.matrix.msc3069.is_guest": False,
+ },
+ )
+
+ def test_GET_whoami_guests(self):
+ channel = self.make_request(
+ b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}"
+ )
+ tok = channel.json_body["access_token"]
+ user_id = channel.json_body["user_id"]
+ device_id = channel.json_body["device_id"]
+
+ whoami = self._whoami(tok)
+ self.assertEqual(
+ whoami,
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ # Unstable until MSC3069 enters spec
+ "org.matrix.msc3069.is_guest": True,
+ },
+ )
def test_GET_whoami_appservices(self):
user_id = "@as:test"
@@ -484,18 +516,25 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server_name,
+ self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": user_id, "exclusive": True}]},
sender=user_id,
)
self.hs.get_datastore().services_cache.append(appservice)
- whoami = self.whoami(as_token)
- self.assertEqual(whoami, {"user_id": user_id})
+ whoami = self._whoami(as_token)
+ self.assertEqual(
+ whoami,
+ {
+ "user_id": user_id,
+ # Unstable until MSC3069 enters spec
+ "org.matrix.msc3069.is_guest": False,
+ },
+ )
self.assertFalse(hasattr(whoami, "device_id"))
- def whoami(self, tok):
+ def _whoami(self, tok):
channel = self.make_request("GET", "account/whoami", {}, access_token=tok)
self.assertEqual(channel.code, 200)
return channel.json_body
@@ -625,7 +664,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed"""
- self.hs.config.enable_3pid_changes = False
+ self.hs.config.registration.enable_3pid_changes = False
client_secret = "foobar"
session_id = self._request_token(self.email, client_secret)
@@ -695,7 +734,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_delete_email_if_disabled(self):
"""Test deleting an email from profile when disallowed"""
- self.hs.config.enable_3pid_changes = False
+ self.hs.config.registration.enable_3pid_changes = False
# Add a threepid
self.get_success(
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index 422361b62a..b9e3602552 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -55,7 +55,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
self.assertEqual(
- self.config.default_room_version.identifier,
+ self.config.server.default_room_version.identifier,
capabilities["m.room_versions"]["default"],
)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index ca2e8ff8ef..becb4e8dcc 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -37,7 +37,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
return self.hs
def test_3pid_lookup_disabled(self):
- self.hs.config.enable_3pid_lookup = False
+ self.hs.config.registration.enable_3pid_lookup = False
self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 371615a015..a63f04bd41 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -94,9 +94,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
- self.hs.config.enable_registration = True
- self.hs.config.registrations_require_3pid = []
- self.hs.config.auto_join_rooms = []
+ self.hs.config.registration.enable_registration = True
+ self.hs.config.registration.registrations_require_3pid = []
+ self.hs.config.registration.auto_join_rooms = []
self.hs.config.captcha.enable_registration_captcha = False
return self.hs
@@ -1064,13 +1064,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
register.register_servlets,
]
- def register_as_user(self, username):
- self.make_request(
- b"POST",
- "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
- {"username": username},
- )
-
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
@@ -1107,7 +1100,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
def test_login_appservice_user(self):
"""Test that an appservice user can use /login"""
- self.register_as_user(AS_USER)
+ self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
@@ -1121,7 +1114,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
def test_login_appservice_user_bot(self):
"""Test that the appservice bot can use /login"""
- self.register_as_user(AS_USER)
+ self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
@@ -1135,7 +1128,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
def test_login_appservice_wrong_user(self):
"""Test that non-as users cannot login with the as token"""
- self.register_as_user(AS_USER)
+ self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
@@ -1149,7 +1142,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
def test_login_appservice_wrong_as(self):
"""Test that as users cannot login with wrong as token"""
- self.register_as_user(AS_USER)
+ self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
@@ -1165,7 +1158,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"""Test that users must provide a token when using the appservice
login method
"""
- self.register_as_user(AS_USER)
+ self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index 1d152352d1..56fe1a3d01 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -50,7 +50,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
PUT to the status endpoint with use_presence enabled will call
set_state on the presence handler.
"""
- self.hs.config.use_presence = True
+ self.hs.config.server.use_presence = True
body = {"presence": "here", "status_msg": "beep boop"}
channel = self.make_request(
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 72a5a11b46..66dcfc9f88 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -50,7 +50,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server_name,
+ self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
@@ -74,7 +74,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server_name,
+ self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
@@ -147,7 +147,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_guest_registration(self):
self.hs.config.key.macaroon_secret_key = "test"
- self.hs.config.allow_guest_access = True
+ self.hs.config.registration.allow_guest_access = True
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
@@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self):
- self.hs.config.allow_guest_access = False
+ self.hs.config.registration.allow_guest_access = False
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 30bdaa9c27..376853fd65 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -784,6 +784,30 @@ class RoomsCreateTestCase(RoomBase):
# Check that do_3pid_invite wasn't called this time.
self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+ def test_spam_checker_may_join_room(self):
+ """Tests that the user_may_join_room spam checker callback is correctly bypassed
+ when creating a new room.
+ """
+
+ async def user_may_join_room(
+ mxid: str,
+ room_id: str,
+ is_invite: bool,
+ ) -> bool:
+ return False
+
+ join_mock = Mock(side_effect=user_may_join_room)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ self.assertEquals(join_mock.call_count, 0)
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@@ -975,6 +999,83 @@ class RoomInviteRatelimitTestCase(RoomBase):
self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+class RoomJoinTestCase(RoomBase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user1 = self.register_user("thomas", "hackme")
+ self.tok1 = self.login("thomas", "hackme")
+
+ self.user2 = self.register_user("teresa", "hackme")
+ self.tok2 = self.login("teresa", "hackme")
+
+ self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+ self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+ self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+
+ def test_spam_checker_may_join_room(self):
+ """Tests that the user_may_join_room spam checker callback is correctly called
+ and blocks room joins when needed.
+ """
+
+ # Register a dummy callback. Make it allow all room joins for now.
+ return_value = True
+
+ async def user_may_join_room(
+ userid: str,
+ room_id: str,
+ is_invited: bool,
+ ) -> bool:
+ return return_value
+
+ callback_mock = Mock(side_effect=user_may_join_room)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+ # Join a first room, without being invited to it.
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room1,
+ False,
+ ),
+ )
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Join a second room, this time with an invite for it.
+ self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+ self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room2,
+ True,
+ ),
+ )
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now make the callback deny all room joins, and check that a join actually fails.
+ return_value = False
+ self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
@@ -2430,3 +2531,73 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
+
+
+class ThreepidInviteTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("thomas", "hackme")
+ self.tok = self.login("thomas", "hackme")
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_threepid_invite_spamcheck(self):
+ # Mock a few functions to prevent the test from failing due to failing to talk to
+ # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
+ # can check its call_count later on during the test.
+ make_invite_mock = Mock(return_value=make_awaitable(0))
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
+ self.hs.get_identity_handler().lookup_3pid = Mock(
+ return_value=make_awaitable(None),
+ )
+
+ # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
+ # allow everything for now.
+ mock = Mock(return_value=make_awaitable(True))
+ self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
+
+ # Send a 3PID invite into the room and check that it succeeded.
+ email_to_invite = "teresa@example.com"
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEquals(channel.code, 200)
+
+ # Check that the callback was called with the right params.
+ mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
+
+ # Check that the call to send the invite was made.
+ make_invite_mock.assert_called_once()
+
+ # Now change the return value of the callback to deny any invite and test that
+ # we can't send the invite.
+ mock.return_value = make_awaitable(False)
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEquals(channel.code, 403)
+
+ # Also check that it stopped before calling _make_and_store_3pid_invite.
+ make_invite_mock.assert_called_once()
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 3075d3f288..71fa87ce92 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -48,7 +48,7 @@ class RestHelper:
def create_room_as(
self,
room_creator: Optional[str] = None,
- is_public: bool = True,
+ is_public: Optional[bool] = None,
room_version: Optional[str] = None,
tok: Optional[str] = None,
expect_code: int = 200,
@@ -62,9 +62,10 @@ class RestHelper:
Args:
room_creator: The user ID to create the room with.
- is_public: If True, the `visibility` parameter will be set to the
- default (public). Otherwise, the `visibility` parameter will be set
- to "private".
+ is_public: If True, the `visibility` parameter will be set to
+ "public". If False, it will be set to "private". If left
+ unspecified, the server will set it to an appropriate default
+ (which should be "private" as per the CS spec).
room_version: The room version to create the room as. Defaults to Synapse's
default room version.
tok: The access token to use in the request.
@@ -77,8 +78,8 @@ class RestHelper:
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
content = extra_content or {}
- if not is_public:
- content["visibility"] = "private"
+ if is_public is not None:
+ content["visibility"] = "public" if is_public else "private"
if room_version:
content["room_version"] = room_version
if tok:
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 4d09b5d07e..8698135a76 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -21,11 +21,13 @@ from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import AccumulatingProtocol
from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
from tests.server import FakeTransport
from tests.test_utils import SMALL_PNG
+from tests.utils import MockClock
try:
import lxml
@@ -723,9 +725,107 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
+ def test_oembed_autodiscovery(self):
+ """
+ Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
+ 1. Request a preview of a URL which is not known to the oEmbed code.
+ 2. It returns HTML including a link to an oEmbed preview.
+ 3. The oEmbed preview is requested and returns a URL for an image.
+ 4. The image is requested for thumbnailing.
+ """
+ # This is a little cheesy in that we use the www subdomain (which isn't the
+ # list of oEmbed patterns) to get "raw" HTML response.
+ self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = b"""
+ <link rel="alternate" type="application/json+oembed"
+ href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json"
+ title="matrixdotorg" />
+ """
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+
+ self.pump()
+
+ # The oEmbed response.
+ result2 = {
+ "version": "1.0",
+ "type": "photo",
+ "url": "http://cdn.twitter.com/matrixdotorg",
+ }
+ oembed_content = json.dumps(result2).encode("utf-8")
+
+ # Ensure a second request is made to the oEmbed URL.
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(oembed_content),)
+ + oembed_content
+ )
+
+ self.pump()
+
+ # Ensure the URL is what was requested.
+ self.assertIn(b"/oembed?", server.data)
+
+ # Ensure a third request is made to the photo URL.
+ client = self.reactor.tcpClients[2][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: image/png\r\n\r\n"
+ )
+ % (len(SMALL_PNG),)
+ + SMALL_PNG
+ )
+
+ self.pump()
+
+ # Ensure the URL is what was requested.
+ self.assertIn(b"/matrixdotorg", server.data)
+
+ self.assertEqual(channel.code, 200)
+ body = channel.json_body
+ self.assertEqual(
+ body["og:url"], "http://www.twitter.com/matrixdotorg/status/12345"
+ )
+ self.assertTrue(body["og:image"].startswith("mxc://"))
+ self.assertEqual(body["og:image:height"], 1)
+ self.assertEqual(body["og:image:width"], 1)
+ self.assertEqual(body["og:image:type"], "image/png")
+
def _download_image(self):
"""Downloads an image into the URL cache.
-
Returns:
A (host, media_id) tuple representing the MXC URI of the image.
"""
@@ -851,3 +951,32 @@ class URLPreviewTests(unittest.HomeserverTestCase):
404,
"URL cache thumbnail was unexpectedly retrieved from a storage provider",
)
+
+ def test_cache_expiry(self):
+ """Test that URL cache files and thumbnails are cleaned up properly on expiry."""
+ self.preview_url.clock = MockClock()
+
+ _host, media_id = self._download_image()
+
+ file_path = self.preview_url.filepaths.url_cache_filepath(media_id)
+ file_dirs = self.preview_url.filepaths.url_cache_filepath_dirs_to_delete(
+ media_id
+ )
+ thumbnail_dir = self.preview_url.filepaths.url_cache_thumbnail_directory(
+ media_id
+ )
+ thumbnail_dirs = self.preview_url.filepaths.url_cache_thumbnail_dirs_to_delete(
+ media_id
+ )
+
+ self.assertTrue(os.path.isfile(file_path))
+ self.assertTrue(os.path.isdir(thumbnail_dir))
+
+ self.preview_url.clock.advance_time_msec(IMAGE_CACHE_EXPIRY_MS + 1)
+ self.get_success(self.preview_url._expire_url_cache_data())
+
+ for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs:
+ self.assertFalse(
+ os.path.exists(path),
+ f"{os.path.relpath(path, self.media_store_path)} was not deleted",
+ )
diff --git a/tests/server.py b/tests/server.py
index 88dfa8058e..64645651ce 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -317,7 +317,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self):
self.threadpool = ThreadPool(self)
- self._tcp_callbacks = {}
+ self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
self._udp = []
self.lookups: Dict[str, str] = {}
self._thread_callbacks: Deque[Callable[[], None]] = deque()
@@ -355,7 +355,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def getThreadPool(self):
return self.threadpool
- def add_tcp_client_callback(self, host, port, callback):
+ def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
"""Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`.
@@ -364,7 +364,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
self._tcp_callbacks[(host, port)] = callback
- def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
+ def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
"""Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP(
@@ -475,7 +475,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
return server
-def get_clock():
+def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock()
hs_clock = Clock(clock)
return clock, hs_clock
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 7f25200a5d..36c495954f 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -346,7 +346,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
invites = []
# Register as many users as the MAU limit allows.
- for i in range(self.hs.config.max_mau_value):
+ for i in range(self.hs.config.server.max_mau_value):
localpart = "user%d" % i
user_id = self.register_user(localpart, "password")
tok = self.login(localpart, "password")
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index ffee707153..7496974da3 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -79,12 +79,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
# Make sure the background update filled in the room creator
room_creator_after = self.get_success(
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index cf9748f218..f26d5acf9c 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -126,7 +126,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.db_pool = database._db_pool
self.engine = database.engine
- db_config = hs.config.get_single_database()
+ db_config = hs.config.database.get_single_database()
self.store = TestTransactionStore(
database, make_conn(db_config, self.engine, "test"), hs
)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 7cc5e621ba..a59c28f896 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -66,12 +66,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
def test_soft_failed_extremities_handled_correctly(self):
"""Test that extremities are correctly calculated in the presence of
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 3cc8038f1e..dada4f98c9 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -242,12 +242,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_devices_last_seen_bg_update(self):
# First make sure we have completed all updates.
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
user_id = "@user:id"
device_id = "MY_DEVICE"
@@ -311,12 +306,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
# We should now get the correct result again
result = self.get_success(
@@ -337,12 +327,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_old_user_ips_pruned(self):
# First make sure we have completed all updates.
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
user_id = "@user:id"
device_id = "MY_DEVICE"
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 93136f0717..b31c5eb5ec 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -578,12 +578,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
# Test that the `has_auth_chain_index` has been set
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
@@ -619,12 +614,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
# Test that the `has_auth_chain_index` has been set
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 944dbc34a2..d6b4cdd788 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -51,7 +51,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
def test_initialise_reserved_users(self):
- threepids = self.hs.config.mau_limits_reserved_threepids
+ threepids = self.hs.config.server.mau_limits_reserved_threepids
# register three users, of which two have reserved 3pids, and a third
# which is a support user.
@@ -101,9 +101,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# XXX some of this is redundant. poking things into the config shouldn't
# work, and in any case it's not obvious what we expect to happen when
# we advance the reactor.
- self.hs.config.max_mau_value = 0
+ self.hs.config.server.max_mau_value = 0
self.reactor.advance(FORTY_DAYS)
- self.hs.config.max_mau_value = 5
+ self.hs.config.server.max_mau_value = 5
self.get_success(self.store.reap_monthly_active_users())
@@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(d)
count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(count, self.hs.config.max_mau_value)
+ self.assertEqual(count, self.hs.config.server.max_mau_value)
self.reactor.advance(FORTY_DAYS)
@@ -199,7 +199,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_reap_monthly_active_users_reserved_users(self):
"""Tests that reaping correctly handles reaping where reserved users are
present"""
- threepids = self.hs.config.mau_limits_reserved_threepids
+ threepids = self.hs.config.server.mau_limits_reserved_threepids
initial_users = len(threepids)
reserved_user_number = initial_users - 1
for i in range(initial_users):
@@ -234,7 +234,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(d)
count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(count, self.hs.config.max_mau_value)
+ self.assertEqual(count, self.hs.config.server.max_mau_value)
def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list
@@ -294,7 +294,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": user2_email},
]
- self.hs.config.mau_limits_reserved_threepids = threepids
+ self.hs.config.server.mau_limits_reserved_threepids = threepids
d = self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index c72dc40510..2873e22ccf 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -169,12 +169,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def test_can_rerun_update(self):
# First make sure we have completed all updates.
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
@@ -197,9 +192,4 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index 6ff3ebb137..ace82cbf42 100644
--- a/tests/storage/test_txn_limit.py
+++ b/tests/storage/test_txn_limit.py
@@ -22,7 +22,7 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
return self.setup_test_homeserver(db_txn_limit=1000)
def test_config(self):
- db_config = self.hs.config.get_single_database()
+ db_config = self.hs.config.database.get_single_database()
self.assertEqual(db_config.config["txn_limit"], 1000)
def test_select(self):
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 222e5d129d..9f483ad681 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -11,7 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, List, Set, Tuple
+from unittest import mock
+from unittest.mock import Mock, patch
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventTypes, Membership, UserTypes
+from synapse.appservice import ApplicationService
+from synapse.rest import admin
+from synapse.rest.client import login, register, room
+from synapse.server import HomeServer
+from synapse.storage import DataStore
+from synapse.storage.roommember import ProfileInfo
+from synapse.util import Clock
+
+from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
ALICE = "@alice:a"
@@ -21,8 +36,376 @@ BOBBY = "@bobby:a"
BELA = "@somenickname:a"
+class GetUserDirectoryTables:
+ """Helper functions that we want to reuse in tests/handlers/test_user_directory.py"""
+
+ def __init__(self, store: DataStore):
+ self.store = store
+
+ def _compress_shared(
+ self, shared: List[Dict[str, str]]
+ ) -> Set[Tuple[str, str, str]]:
+ """
+ Compress a list of users who share rooms dicts to a list of tuples.
+ """
+ r = set()
+ for i in shared:
+ r.add((i["user_id"], i["other_user_id"], i["room_id"]))
+ return r
+
+ async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
+ """Fetch the entire `users_in_public_rooms` table.
+
+ Returns a list of tuples (user_id, room_id) where room_id is public and
+ contains the user with the given id.
+ """
+ r = await self.store.db_pool.simple_select_list(
+ "users_in_public_rooms", None, ("user_id", "room_id")
+ )
+
+ retval = []
+ for i in r:
+ retval.append((i["user_id"], i["room_id"]))
+ return retval
+
+ async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]:
+ """Fetch the entire `users_who_share_private_rooms` table.
+
+ Returns a dict containing "user_id", "other_user_id" and "room_id" keys.
+ The dicts can be flattened to Tuples with the `_compress_shared` method.
+ (This seems a little awkward---maybe we could clean this up.)
+ """
+
+ return await self.store.db_pool.simple_select_list(
+ "users_who_share_private_rooms",
+ None,
+ ["user_id", "other_user_id", "room_id"],
+ )
+
+ async def get_users_in_user_directory(self) -> Set[str]:
+ """Fetch the set of users in the `user_directory` table.
+
+ This is useful when checking we've correctly excluded users from the directory.
+ """
+ result = await self.store.db_pool.simple_select_list(
+ "user_directory",
+ None,
+ ["user_id"],
+ )
+ return {row["user_id"] for row in result}
+
+ async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]:
+ """Fetch users and their profiles from the `user_directory` table.
+
+ This is useful when we want to inspect display names and avatars.
+ It's almost the entire contents of the `user_directory` table: the only
+ thing missing is an unused room_id column.
+ """
+ rows = await self.store.db_pool.simple_select_list(
+ "user_directory",
+ None,
+ ("user_id", "display_name", "avatar_url"),
+ )
+ return {
+ row["user_id"]: ProfileInfo(
+ display_name=row["display_name"], avatar_url=row["avatar_url"]
+ )
+ for row in rows
+ }
+
+
+class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
+ """Ensure that rebuilding the directory writes the correct data to the DB.
+
+ See also tests/handlers/test_user_directory.py for similar checks. They
+ test the incremental updates, rather than the big rebuild.
+ """
+
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ register.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.appservice = ApplicationService(
+ token="i_am_an_app_service",
+ hostname="test",
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ sender="@as:test",
+ )
+
+ mock_load_appservices = Mock(return_value=[self.appservice])
+ with patch(
+ "synapse.storage.databases.main.appservice.load_appservices",
+ mock_load_appservices,
+ ):
+ hs = super().make_homeserver(reactor, clock)
+ return hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastore()
+ self.user_dir_helper = GetUserDirectoryTables(self.store)
+
+ def _purge_and_rebuild_user_dir(self) -> None:
+ """Nuke the user directory tables, start the background process to
+ repopulate them, and wait for the process to complete. This allows us
+ to inspect the outcome of the background process alone, without any of
+ the other incremental updates.
+ """
+ self.get_success(self.store.update_user_directory_stream_pos(None))
+ self.get_success(self.store.delete_all_from_user_dir())
+
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
+
+ # Nothing updated yet
+ self.assertEqual(shares_private, [])
+ self.assertEqual(public_users, [])
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_user_directory_createtables",
+ "progress_json": "{}",
+ },
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_user_directory_process_rooms",
+ "progress_json": "{}",
+ "depends_on": "populate_user_directory_createtables",
+ },
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_user_directory_process_users",
+ "progress_json": "{}",
+ "depends_on": "populate_user_directory_process_rooms",
+ },
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_user_directory_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_user_directory_process_users",
+ },
+ )
+ )
+
+ self.wait_for_background_updates()
+
+ def test_initial(self) -> None:
+ """
+ The user directory's initial handler correctly updates the search tables.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+ u3 = self.register_user("user3", "pass")
+ u3_token = self.login(u3, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+ self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token)
+ self.helper.join(private_room, user=u3, tok=u3_token)
+
+ # Do the initial population of the user directory via the background update
+ self._purge_and_rebuild_user_dir()
+
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
+
+ # User 1 and User 2 are in the same public room
+ self.assertEqual(set(public_users), {(u1, room), (u2, room)})
+
+ # User 1 and User 3 share private rooms
+ self.assertEqual(
+ self.user_dir_helper._compress_shared(shares_private),
+ {(u1, u3, private_room), (u3, u1, private_room)},
+ )
+
+ # All three should have entries in the directory
+ users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
+ self.assertEqual(users, {u1, u2, u3})
+
+ # The next three tests (test_population_excludes_*) all set up
+ # - A normal user included in the user dir
+ # - A public and private room created by that user
+ # - A user excluded from the room dir, belonging to both rooms
+
+ # They match similar logic in handlers/test_user_directory.py But that tests
+ # updating the directory; this tests rebuilding it from scratch.
+
+ def _create_rooms_and_inject_memberships(
+ self, creator: str, token: str, joiner: str
+ ) -> Tuple[str, str]:
+ """Create a public and private room as a normal user.
+ Then get the `joiner` into those rooms.
+ """
+ public_room = self.helper.create_room_as(
+ creator,
+ is_public=True,
+ # See https://github.com/matrix-org/synapse/issues/10951
+ extra_content={"visibility": "public"},
+ tok=token,
+ )
+ private_room = self.helper.create_room_as(creator, is_public=False, tok=token)
+
+ # HACK: get the user into these rooms
+ self.get_success(inject_member_event(self.hs, public_room, joiner, "join"))
+ self.get_success(inject_member_event(self.hs, private_room, joiner, "join"))
+
+ return public_room, private_room
+
+ def _check_room_sharing_tables(
+ self, normal_user: str, public_room: str, private_room: str
+ ) -> None:
+ # After rebuilding the directory, we should only see the normal user.
+ users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
+ self.assertEqual(users, {normal_user})
+ in_public_rooms = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
+ self.assertEqual(set(in_public_rooms), {(normal_user, public_room)})
+ in_private_rooms = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ self.assertEqual(in_private_rooms, [])
+
+ def test_population_excludes_support_user(self) -> None:
+ # Create a normal and support user.
+ user = self.register_user("user", "pass")
+ token = self.login(user, "pass")
+ support = "@support1:test"
+ self.get_success(
+ self.store.register_user(
+ user_id=support, password_hash=None, user_type=UserTypes.SUPPORT
+ )
+ )
+
+ # Join the support user to rooms owned by the normal user.
+ public, private = self._create_rooms_and_inject_memberships(
+ user, token, support
+ )
+
+ # Rebuild the directory.
+ self._purge_and_rebuild_user_dir()
+
+ # Check the support user is not in the directory.
+ self._check_room_sharing_tables(user, public, private)
+
+ def test_population_excludes_deactivated_user(self) -> None:
+ user = self.register_user("naughty", "pass")
+ admin = self.register_user("admin", "pass", admin=True)
+ admin_token = self.login(admin, "pass")
+
+ # Deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ f"/_synapse/admin/v2/users/{user}",
+ access_token=admin_token,
+ content={"deactivated": True},
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["deactivated"], True)
+
+ # Join the deactivated user to rooms owned by the admin.
+ # Is this something that could actually happen outside of a test?
+ public, private = self._create_rooms_and_inject_memberships(
+ admin, admin_token, user
+ )
+
+ # Rebuild the user dir. The deactivated user should be missing.
+ self._purge_and_rebuild_user_dir()
+ self._check_room_sharing_tables(admin, public, private)
+
+ def test_population_excludes_appservice_user(self) -> None:
+ # Register an AS user.
+ user = self.register_user("user", "pass")
+ token = self.login(user, "pass")
+ as_user = self.register_appservice_user("as_user_potato", self.appservice.token)
+
+ # Join the AS user to rooms owned by the normal user.
+ public, private = self._create_rooms_and_inject_memberships(
+ user, token, as_user
+ )
+
+ # Rebuild the directory.
+ self._purge_and_rebuild_user_dir()
+
+ # Check the AS user is not in the directory.
+ self._check_room_sharing_tables(user, public, private)
+
+ def test_population_conceals_private_nickname(self) -> None:
+ # Make a private room, and set a nickname within
+ user = self.register_user("aaaa", "pass")
+ user_token = self.login(user, "pass")
+ private_room = self.helper.create_room_as(user, is_public=False, tok=user_token)
+ self.helper.send_state(
+ private_room,
+ EventTypes.Member,
+ state_key=user,
+ body={"membership": Membership.JOIN, "displayname": "BBBB"},
+ tok=user_token,
+ )
+
+ # Rebuild the user directory. Make the rescan of the `users` table a no-op
+ # so we only see the effect of scanning the `room_memberships` table.
+ async def mocked_process_users(*args: Any, **kwargs: Any) -> int:
+ await self.store.db_pool.updates._end_background_update(
+ "populate_user_directory_process_users"
+ )
+ return 1
+
+ with mock.patch.dict(
+ self.store.db_pool.updates._background_update_handlers,
+ populate_user_directory_process_users=mocked_process_users,
+ ):
+ self._purge_and_rebuild_user_dir()
+
+ # Local users are ignored by the scan over rooms
+ users = self.get_success(self.user_dir_helper.get_profiles_in_user_directory())
+ self.assertEqual(users, {})
+
+ # Do a full rebuild including the scan over the `users` table. The local
+ # user should appear with their profile name.
+ self._purge_and_rebuild_user_dir()
+ users = self.get_success(self.user_dir_helper.get_profiles_in_user_directory())
+ self.assertEqual(
+ users, {user: ProfileInfo(display_name="aaaa", avatar_url=None)}
+ )
+
+
class UserDirectoryStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares
@@ -33,7 +416,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
- def test_search_user_dir(self):
+ def test_search_user_dir(self) -> None:
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
@@ -44,7 +427,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
)
@override_config({"user_directory": {"search_all_users": True}})
- def test_search_user_dir_all_users(self):
+ def test_search_user_dir_all_users(self) -> None:
r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"]))
@@ -58,7 +441,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
)
@override_config({"user_directory": {"search_all_users": True}})
- def test_search_user_dir_stop_words(self):
+ def test_search_user_dir_stop_words(self) -> None:
"""Tests that a user can look up another user by searching for the start if its
display name even if that name happens to be a common English word that would
usually be ignored in full text searches.
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 1a4d078780..cf407c51cf 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -38,21 +38,19 @@ class EventAuthTestCase(unittest.TestCase):
}
# creator should be able to send state
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_random_state_event(creator),
auth_events,
- do_sig_check=False,
)
# joiner should not be able to send state
self.assertRaises(
AuthError,
- event_auth.check,
+ event_auth.check_auth_rules_for_event,
RoomVersions.V1,
_random_state_event(joiner),
auth_events,
- do_sig_check=False,
)
def test_state_default_level(self):
@@ -77,19 +75,17 @@ class EventAuthTestCase(unittest.TestCase):
# pleb should not be able to send state
self.assertRaises(
AuthError,
- event_auth.check,
+ event_auth.check_auth_rules_for_event,
RoomVersions.V1,
_random_state_event(pleb),
auth_events,
- do_sig_check=False,
),
# king should be able to send state
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_random_state_event(king),
auth_events,
- do_sig_check=False,
)
def test_alias_event(self):
@@ -102,37 +98,33 @@ class EventAuthTestCase(unittest.TestCase):
}
# creator should be able to send aliases
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_alias_event(creator),
auth_events,
- do_sig_check=False,
)
# Reject an event with no state key.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_alias_event(creator, state_key=""),
auth_events,
- do_sig_check=False,
)
# If the domain of the sender does not match the state key, reject.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_alias_event(creator, state_key="test.com"),
auth_events,
- do_sig_check=False,
)
# Note that the member does *not* need to be in the room.
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_alias_event(other),
auth_events,
- do_sig_check=False,
)
def test_msc2432_alias_event(self):
@@ -145,34 +137,30 @@ class EventAuthTestCase(unittest.TestCase):
}
# creator should be able to send aliases
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_alias_event(creator),
auth_events,
- do_sig_check=False,
)
# No particular checks are done on the state key.
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_alias_event(creator, state_key=""),
auth_events,
- do_sig_check=False,
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_alias_event(creator, state_key="test.com"),
auth_events,
- do_sig_check=False,
)
# Per standard auth rules, the member must be in the room.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_alias_event(other),
auth_events,
- do_sig_check=False,
)
def test_msc2209(self):
@@ -192,20 +180,18 @@ class EventAuthTestCase(unittest.TestCase):
}
# pleb should be able to modify the notifications power level.
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_power_levels_event(pleb, {"notifications": {"room": 100}}),
auth_events,
- do_sig_check=False,
)
# But an MSC2209 room rejects this change.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_power_levels_event(pleb, {"notifications": {"room": 100}}),
auth_events,
- do_sig_check=False,
)
def test_join_rules_public(self):
@@ -222,59 +208,53 @@ class EventAuthTestCase(unittest.TestCase):
}
# Check join.
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user cannot be force-joined to a room.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_member_event(pleb, "join", sender=creator),
auth_events,
- do_sig_check=False,
)
# Banned should be rejected.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user who left can re-join.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user can send a join if they're in the room.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user can accept an invite.
auth_events[("m.room.member", pleb)] = _member_event(
pleb, "invite", sender=creator
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
def test_join_rules_invite(self):
@@ -292,60 +272,54 @@ class EventAuthTestCase(unittest.TestCase):
# A join without an invite is rejected.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user cannot be force-joined to a room.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_member_event(pleb, "join", sender=creator),
auth_events,
- do_sig_check=False,
)
# Banned should be rejected.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user who left cannot re-join.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user can send a join if they're in the room.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user can accept an invite.
auth_events[("m.room.member", pleb)] = _member_event(
pleb, "invite", sender=creator
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
def test_join_rules_msc3083_restricted(self):
@@ -370,11 +344,10 @@ class EventAuthTestCase(unittest.TestCase):
# Older room versions don't understand this join rule
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A properly formatted join event should work.
@@ -384,11 +357,10 @@ class EventAuthTestCase(unittest.TestCase):
EventContentFields.AUTHORISING_USER: "@creator:example.com"
},
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
authorised_join_event,
auth_events,
- do_sig_check=False,
)
# A join issued by a specific user works (i.e. the power level checks
@@ -400,7 +372,7 @@ class EventAuthTestCase(unittest.TestCase):
pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event(
"@inviter:foo.test"
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(
pleb,
@@ -409,16 +381,14 @@ class EventAuthTestCase(unittest.TestCase):
},
),
pl_auth_events,
- do_sig_check=False,
)
# A join which is missing an authorised server is rejected.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# An join authorised by a user who is not in the room is rejected.
@@ -427,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase):
creator, {"invite": 100, "users": {"@other:example.com": 150}}
)
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(
pleb,
@@ -436,13 +406,12 @@ class EventAuthTestCase(unittest.TestCase):
},
),
auth_events,
- do_sig_check=False,
)
# A user cannot be force-joined to a room. (This uses an event which
# *would* be valid, but is sent be a different user.)
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_member_event(
pleb,
@@ -453,36 +422,32 @@ class EventAuthTestCase(unittest.TestCase):
},
),
auth_events,
- do_sig_check=False,
)
# Banned should be rejected.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
authorised_join_event,
auth_events,
- do_sig_check=False,
)
# A user who left can re-join.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
authorised_join_event,
auth_events,
- do_sig_check=False,
)
# A user can send a join if they're in the room. (This doesn't need to
# be authorised since the user is already joined.)
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user can accept an invite. (This doesn't need to be authorised since
@@ -490,11 +455,10 @@ class EventAuthTestCase(unittest.TestCase):
auth_events[("m.room.member", pleb)] = _member_event(
pleb, "invite", sender=creator
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c51e018da1..24fc77d7a7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -82,7 +82,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
event,
context,
state=None,
- claimed_auth_event_map=None,
backfilled=False,
):
return context
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 66111eb367..c683c8937e 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -13,11 +13,11 @@
# limitations under the License.
"""Tests REST events for /rooms paths."""
-
+import synapse.rest.admin
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
-from synapse.rest.client import register, sync
+from synapse.rest.client import login, profile, register, sync
from tests import unittest
from tests.unittest import override_config
@@ -26,7 +26,13 @@ from tests.utils import default_config
class TestMauLimit(unittest.HomeserverTestCase):
- servlets = [register.register_servlets, sync.register_servlets]
+ servlets = [
+ register.register_servlets,
+ sync.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ profile.register_servlets,
+ login.register_servlets,
+ ]
def default_config(self):
config = default_config("test")
@@ -165,7 +171,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
@override_config({"mau_trial_days": 1})
def test_trial_users_cant_come_back(self):
- self.hs.config.mau_trial_days = 1
+ self.hs.config.server.mau_trial_days = 1
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
@@ -229,6 +235,31 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.reactor.advance(100)
self.assertEqual(2, self.successResultOf(count))
+ def test_deactivated_users_dont_count_towards_mau(self):
+ user1 = self.register_user("madonna", "password")
+ self.register_user("prince", "password2")
+ self.register_user("frodo", "onering", True)
+
+ token1 = self.login("madonna", "password")
+ token2 = self.login("prince", "password2")
+ admin_token = self.login("frodo", "onering")
+
+ self.do_sync_for_user(token1)
+ self.do_sync_for_user(token2)
+
+ # Check that mau count is what we expect
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 2)
+
+ # Deactivate user1
+ url = "/_synapse/admin/v1/deactivate/%s" % user1
+ channel = self.make_request("POST", url, access_token=admin_token)
+ self.assertIn("success", channel.json_body["id_server_unbind_result"])
+
+ # Check that deactivated user is no longer counted
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 1)
+
def create_user(self, localpart, token=None, appservice=False):
request_data = {
"username": localpart,
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 48e792b55b..09e017b4d9 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -13,7 +13,8 @@
# limitations under the License.
from synapse.rest.media.v1.preview_url_resource import (
- decode_and_calc_og,
+ _calc_og,
+ decode_body,
get_html_media_encoding,
summarize_paragraphs,
)
@@ -158,7 +159,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -173,7 +175,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -191,7 +194,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(
og,
@@ -212,7 +216,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -225,7 +230,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -239,7 +245,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
@@ -253,21 +260,22 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_empty(self):
"""Test a body with no data in it."""
html = b""
- og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEqual(og, {})
+ tree = decode_body(html)
+ self.assertIsNone(tree)
def test_no_tree(self):
"""A valid body with no tree in it."""
html = b"\x00"
- og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEqual(og, {})
+ tree = decode_body(html)
+ self.assertIsNone(tree)
def test_invalid_encoding(self):
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
@@ -279,9 +287,8 @@ class CalcOgTestCase(unittest.TestCase):
</body>
</html>
"""
- og = decode_and_calc_og(
- html, "http://example.com/test.html", "invalid-encoding"
- )
+ tree = decode_body(html, "invalid-encoding")
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding2(self):
@@ -295,7 +302,8 @@ class CalcOgTestCase(unittest.TestCase):
</body>
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
diff --git a/tests/unittest.py b/tests/unittest.py
index 7a6f5954d0..81c1a9e9d2 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import inspect
import logging
import secrets
import time
-from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
from unittest.mock import Mock, patch
from canonicaljson import json
@@ -28,6 +28,7 @@ from canonicaljson import json
from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
+from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from twisted.web.resource import Resource
@@ -46,6 +47,7 @@ from synapse.logging.context import (
)
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -232,7 +234,7 @@ class HomeserverTestCase(TestCase):
# Honour the `use_frozen_dicts` config option. We have to do this
# manually because this is taken care of in the app `start` code, which
# we don't run. Plus we want to reset it on tearDown.
- events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts
+ events.USE_FROZEN_DICTS = self.hs.config.server.use_frozen_dicts
if self.hs is None:
raise Exception("No homeserver returned from make_homeserver.")
@@ -315,6 +317,15 @@ class HomeserverTestCase(TestCase):
self.reactor.advance(0.01)
time.sleep(0.01)
+ def wait_for_background_updates(self) -> None:
+ """Block until all background database updates have completed."""
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
def make_homeserver(self, reactor, clock):
"""
Make and return a homeserver.
@@ -371,7 +382,7 @@ class HomeserverTestCase(TestCase):
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
"""
Prepare for the test. This involves things like mocking out parts of
the homeserver, or building test data common across the whole test
@@ -447,7 +458,7 @@ class HomeserverTestCase(TestCase):
client_ip,
)
- def setup_test_homeserver(self, *args, **kwargs):
+ def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
"""
Set up the test homeserver, meant to be called by the overridable
make_homeserver. It automatically passes through the test class's
@@ -558,7 +569,7 @@ class HomeserverTestCase(TestCase):
Returns:
The MXID of the new user.
"""
- self.hs.config.registration_shared_secret = "shared"
+ self.hs.config.registration.registration_shared_secret = "shared"
# Create the user
channel = self.make_request("GET", "/_synapse/admin/v1/register")
@@ -594,6 +605,35 @@ class HomeserverTestCase(TestCase):
user_id = channel.json_body["user_id"]
return user_id
+ def register_appservice_user(
+ self,
+ username: str,
+ appservice_token: str,
+ ) -> str:
+ """Register an appservice user as an application service.
+ Requires the client-facing registration API be registered.
+
+ Args:
+ username: the user to be registered by an application service.
+ Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart"
+ appservice_token: the acccess token for that application service.
+
+ Raises: if the request to '/register' does not return 200 OK.
+
+ Returns: the MXID of the new user.
+ """
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/register",
+ {
+ "username": username,
+ "type": "m.login.application_service",
+ },
+ access_token=appservice_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ return channel.json_body["user_id"]
+
def login(
self,
username,
|