diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index db9f86bdac..04b8c2c07c 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -22,7 +22,7 @@ from synapse.appservice.scheduler import (
_ServiceQueuer,
_TransactionController,
)
-from synapse.util.logcontext import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable
from tests import unittest
diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py
new file mode 100644
index 0000000000..13ab282384
--- /dev/null
+++ b/tests/config/test_ratelimiting.py
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.config.homeserver import HomeServerConfig
+
+from tests.unittest import TestCase
+from tests.utils import default_config
+
+
+class RatelimitConfigTestCase(TestCase):
+ def test_parse_rc_federation(self):
+ config_dict = default_config("test")
+ config_dict["rc_federation"] = {
+ "window_size": 20000,
+ "sleep_limit": 693,
+ "sleep_delay": 252,
+ "reject_limit": 198,
+ "concurrent": 7,
+ }
+
+ config = HomeServerConfig()
+ config.parse_config_dict(config_dict, "", "")
+ config_obj = config.rc_federation
+
+ self.assertEqual(config_obj.window_size, 20000)
+ self.assertEqual(config_obj.sleep_limit, 693)
+ self.assertEqual(config_obj.sleep_delay, 252)
+ self.assertEqual(config_obj.reject_limit, 198)
+ self.assertEqual(config_obj.concurrent, 7)
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index a5d88d644a..4f8a87a3df 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +16,10 @@
import os
-from synapse.config.tls import TlsConfig
+from OpenSSL import SSL
+
+from synapse.config.tls import ConfigError, TlsConfig
+from synapse.crypto.context_factory import ClientTLSOptionsFactory
from tests.unittest import TestCase
@@ -78,3 +82,112 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
"or use Synapse's ACME support to provision one."
),
)
+
+ def test_tls_client_minimum_default(self):
+ """
+ The default client TLS version is 1.0.
+ """
+ config = {}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(t.federation_client_minimum_tls_version, "1")
+
+ def test_tls_client_minimum_set(self):
+ """
+ The default client TLS version can be set to 1.0, 1.1, and 1.2.
+ """
+ config = {"federation_client_minimum_tls_version": 1}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.federation_client_minimum_tls_version, "1")
+
+ config = {"federation_client_minimum_tls_version": 1.1}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.federation_client_minimum_tls_version, "1.1")
+
+ config = {"federation_client_minimum_tls_version": 1.2}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.federation_client_minimum_tls_version, "1.2")
+
+ # Also test a string version
+ config = {"federation_client_minimum_tls_version": "1"}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.federation_client_minimum_tls_version, "1")
+
+ config = {"federation_client_minimum_tls_version": "1.2"}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.federation_client_minimum_tls_version, "1.2")
+
+ def test_tls_client_minimum_1_point_3_missing(self):
+ """
+ If TLS 1.3 support is missing and it's configured, it will raise a
+ ConfigError.
+ """
+ # thanks i hate it
+ if hasattr(SSL, "OP_NO_TLSv1_3"):
+ OP_NO_TLSv1_3 = SSL.OP_NO_TLSv1_3
+ delattr(SSL, "OP_NO_TLSv1_3")
+ self.addCleanup(setattr, SSL, "SSL.OP_NO_TLSv1_3", OP_NO_TLSv1_3)
+ assert not hasattr(SSL, "OP_NO_TLSv1_3")
+
+ config = {"federation_client_minimum_tls_version": 1.3}
+ t = TestConfig()
+ with self.assertRaises(ConfigError) as e:
+ t.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(
+ e.exception.args[0],
+ (
+ "federation_client_minimum_tls_version cannot be 1.3, "
+ "your OpenSSL does not support it"
+ ),
+ )
+
+ def test_tls_client_minimum_1_point_3_exists(self):
+ """
+ If TLS 1.3 support exists and it's configured, it will be settable.
+ """
+ # thanks i hate it, still
+ if not hasattr(SSL, "OP_NO_TLSv1_3"):
+ SSL.OP_NO_TLSv1_3 = 0x00
+ self.addCleanup(lambda: delattr(SSL, "OP_NO_TLSv1_3"))
+ assert hasattr(SSL, "OP_NO_TLSv1_3")
+
+ config = {"federation_client_minimum_tls_version": 1.3}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.federation_client_minimum_tls_version, "1.3")
+
+ def test_tls_client_minimum_set_passed_through_1_2(self):
+ """
+ The configured TLS version is correctly configured by the ContextFactory.
+ """
+ config = {"federation_client_minimum_tls_version": 1.2}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cf = ClientTLSOptionsFactory(t)
+
+ # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2
+ self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0)
+ self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0)
+ self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0)
+
+ def test_tls_client_minimum_set_passed_through_1_0(self):
+ """
+ The configured TLS version is correctly configured by the ContextFactory.
+ """
+ config = {"federation_client_minimum_tls_version": 1}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cf = ClientTLSOptionsFactory(t)
+
+ # The context has not had any of the NO_TLS set.
+ self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0)
+ self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0)
+ self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0)
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 5a355f00cc..795703967d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -30,9 +30,12 @@ from synapse.crypto.keyring import (
ServerKeyFetcher,
StoreKeyFetcher,
)
+from synapse.logging.context import (
+ LoggingContext,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+)
from synapse.storage.keys import FetchKeyResult
-from synapse.util import logcontext
-from synapse.util.logcontext import LoggingContext
from tests import unittest
@@ -131,7 +134,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
self.assertEquals(LoggingContext.current_context().request, "11")
- with logcontext.PreserveLoggingContext():
+ with PreserveLoggingContext():
yield persp_deferred
defer.returnValue(persp_resp)
@@ -158,7 +161,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
- yield logcontext.make_deferred_yieldable(res_deferreds[0])
+ yield make_deferred_yieldable(res_deferreds[0])
# let verify_json_objects_for_server finish its work before we kill the
# logcontext
@@ -184,7 +187,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
[("server10", json1, 0, "test")]
)
res_deferreds_2[0].addBoth(self.check_context, None)
- yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
+ yield make_deferred_yieldable(res_deferreds_2[0])
# let verify_json_objects_for_server finish its work before we kill the
# logcontext
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 4edce7af43..1c7ded7397 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.constants import UserTypes
-from synapse.api.errors import ResourceLimitError, SynapseError
+from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
from synapse.types import RoomAlias, UserID, create_requester
@@ -67,7 +67,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = frank.to_string()
requester = create_requester(user_id)
result_user_id, result_token = self.get_success(
- self.handler.get_or_create_user(requester, frank.localpart, "Frankie")
+ self.get_or_create_user(requester, frank.localpart, "Frankie")
)
self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None)
@@ -87,7 +87,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = frank.to_string()
requester = create_requester(user_id)
result_user_id, result_token = self.get_success(
- self.handler.get_or_create_user(requester, local_part, None)
+ self.get_or_create_user(requester, local_part, None)
)
self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None)
@@ -95,9 +95,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_mau_limits_when_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- self.get_success(
- self.handler.get_or_create_user(self.requester, "a", "display_name")
- )
+ self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
@@ -105,7 +103,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
- self.get_success(self.handler.get_or_create_user(self.requester, "c", "User"))
+ self.get_success(self.get_or_create_user(self.requester, "c", "User"))
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
@@ -113,7 +111,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.lots_of_users)
)
self.get_failure(
- self.handler.get_or_create_user(self.requester, "b", "display_name"),
+ self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)
@@ -121,7 +119,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.hs.config.max_mau_value)
)
self.get_failure(
- self.handler.get_or_create_user(self.requester, "b", "display_name"),
+ self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)
@@ -232,3 +230,55 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_invalid_user_id_length(self):
invalid_user_id = "x" * 256
self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError)
+
+ @defer.inlineCallbacks
+ def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+ """Creates a new user if the user does not exist,
+ else revokes all previous access tokens and generates a new one.
+
+ XXX: this used to be in the main codebase, but was only used by this file,
+ so got moved here. TODO: get rid of it, probably
+
+ Args:
+ localpart : The local part of the user ID to register. If None,
+ one will be randomly generated.
+ Returns:
+ A tuple of (user_id, access_token).
+ Raises:
+ RegistrationError if there was a problem registering.
+ """
+ if localpart is None:
+ raise SynapseError(400, "Request must include user id")
+ yield self.hs.get_auth().check_auth_blocking()
+ need_register = True
+
+ try:
+ yield self.handler.check_username(localpart)
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ need_register = False
+ else:
+ raise
+
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+ token = self.macaroon_generator.generate_access_token(user_id)
+
+ if need_register:
+ yield self.handler.register_with_store(
+ user_id=user_id,
+ token=token,
+ password_hash=password_hash,
+ create_profile_with_displayname=user.localpart,
+ )
+ else:
+ yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
+ yield self.store.add_access_token_to_user(user_id=user_id, token=token)
+
+ if displayname is not None:
+ # logger.info("setting user display name: %s -> %s", user_id, displayname)
+ yield self.hs.get_profile_handler().set_displayname(
+ user, requester, displayname, by_admin=True
+ )
+
+ defer.returnValue((user_id, token))
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 417fda3ab2..a49f9b3224 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -36,8 +36,8 @@ from synapse.http.federation.matrix_federation_agent import (
_cache_period_from_headers,
)
from synapse.http.federation.srv_resolver import Server
+from synapse.logging.context import LoggingContext
from synapse.util.caches.ttlcache import TTLCache
-from synapse.util.logcontext import LoggingContext
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.server import FakeTransport, ThreadedMemoryReactorClock
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index cf6c6e95b5..65b51dc981 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
-from synapse.util.logcontext import LoggingContext
+from synapse.logging.context import LoggingContext
from tests import unittest
from tests.utils import MockClock
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index c4c0d9b968..b9d6d7ad1c 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -29,7 +29,7 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient,
MatrixFederationRequest,
)
-from synapse.util.logcontext import LoggingContext
+from synapse.logging.context import LoggingContext
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase
diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
index ee0add3455..220884311c 100644
--- a/tests/patch_inline_callbacks.py
+++ b/tests/patch_inline_callbacks.py
@@ -28,7 +28,7 @@ def do_patch():
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
"""
- from synapse.util.logcontext import LoggingContext
+ from synapse.logging.context import LoggingContext
orig_inline_callbacks = defer.inlineCallbacks
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 22c3f73ef3..8ce6bb62da 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -18,8 +18,8 @@ from mock import Mock
from twisted.internet.defer import Deferred
import synapse.rest.admin
+from synapse.logging.context import make_deferred_yieldable
from synapse.rest.client.v1 import login, room
-from synapse.util.logcontext import make_deferred_yieldable
from tests.unittest import HomeserverTestCase
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 708dc26e61..a8adc9a61d 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -2,9 +2,9 @@ from mock import Mock, call
from twisted.internet import defer, reactor
+from synapse.logging.context import LoggingContext
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
from synapse.util import Clock
-from synapse.util.logcontext import LoggingContext
from tests import unittest
from tests.utils import MockClock
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 3deeed3a70..6bb7d92638 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -466,9 +466,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["content"], new_body)
- self.assertEquals(
- channel.json_body["unsigned"].get("m.relations"),
- {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ relations_dict = channel.json_body["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
def test_multi_edit(self):
@@ -518,9 +524,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["content"], new_body)
- self.assertEquals(
- channel.json_body["unsigned"].get("m.relations"),
- {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ relations_dict = channel.json_body["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
def _send_relation(
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index e2d418b1df..bc662b61db 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -22,27 +22,28 @@ from binascii import unhexlify
from mock import Mock
from six.moves.urllib import parse
-from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred
+from synapse.logging.context import make_deferred_yieldable
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
-from synapse.util.logcontext import make_deferred_yieldable
from tests import unittest
-class MediaStorageTests(unittest.TestCase):
- def setUp(self):
+class MediaStorageTests(unittest.HomeserverTestCase):
+
+ needs_threadpool = True
+
+ def prepare(self, reactor, clock, hs):
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
+ self.addCleanup(shutil.rmtree, self.test_dir)
self.primary_base_path = os.path.join(self.test_dir, "primary")
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
- hs = Mock()
- hs.get_reactor = Mock(return_value=reactor)
hs.config.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
@@ -52,10 +53,6 @@ class MediaStorageTests(unittest.TestCase):
hs, self.primary_base_path, self.filepaths, storage_providers
)
- def tearDown(self):
- shutil.rmtree(self.test_dir)
-
- @defer.inlineCallbacks
def test_ensure_media_is_in_local_cache(self):
media_id = "some_media_id"
test_body = "Test\n"
@@ -73,7 +70,15 @@ class MediaStorageTests(unittest.TestCase):
# Now we run ensure_media_is_in_local_cache, which should copy the file
# to the local cache.
file_info = FileInfo(None, media_id)
- local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info)
+
+ # This uses a real blocking threadpool so we have to wait for it to be
+ # actually done :/
+ x = self.media_storage.ensure_media_is_in_local_cache(file_info)
+
+ # Hotloop until the threadpool does its job...
+ self.wait_on_thread(x)
+
+ local_path = self.get_success(x)
self.assertTrue(os.path.exists(local_path))
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 8fe5961866..976652aee8 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -460,3 +460,15 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"error": "DNS resolution failure during URL preview generation",
},
)
+
+ def test_OPTIONS(self):
+ """
+ OPTIONS returns the OPTIONS.
+ """
+ request, channel = self.make_request(
+ "OPTIONS", "url_preview?url=http://example.com", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body, {})
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 6a8339b561..a73f18f88e 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -3,9 +3,9 @@ from mock import Mock
from twisted.internet.defer import maybeDeferred, succeed
from synapse.events import FrozenEvent
+from synapse.logging.context import LoggingContext
from synapse.types import Requester, UserID
from synapse.util import Clock
-from synapse.util.logcontext import LoggingContext
from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
diff --git a/tests/test_server.py b/tests/test_server.py
index da29ae92ce..ba08483a4b 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -26,8 +26,8 @@ from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite, logger
+from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
-from synapse.util.logcontext import make_deferred_yieldable
from tests import unittest
from tests.server import (
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 813f984199..2d96b0fa8d 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -17,7 +17,7 @@ import os
import twisted.logger
-from synapse.util.logcontext import LoggingContextFilter
+from synapse.logging.context import LoggingContextFilter
class ToTwistedHandler(logging.Handler):
diff --git a/tests/unittest.py b/tests/unittest.py
index 36df43c137..a09e76c7c2 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -17,6 +17,7 @@ import gc
import hashlib
import hmac
import logging
+import time
from mock import Mock
@@ -24,16 +25,17 @@ from canonicaljson import json
import twisted
import twisted.logger
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import Deferred, succeed
+from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from synapse.api.constants import EventTypes
from synapse.config.homeserver import HomeServerConfig
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
+from synapse.logging.context import LoggingContext
from synapse.server import HomeServer
from synapse.types import Requester, UserID, create_requester
-from synapse.util.logcontext import LoggingContext
from tests.server import get_clock, make_request, render, setup_test_homeserver
from tests.test_utils.logging_setup import setup_logging
@@ -164,6 +166,7 @@ class HomeserverTestCase(TestCase):
servlets = []
hijack_auth = True
+ needs_threadpool = False
def setUp(self):
"""
@@ -192,15 +195,19 @@ class HomeserverTestCase(TestCase):
if self.hijack_auth:
def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.helper.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
+ return succeed(
+ {
+ "user": UserID.from_string(self.helper.auth_user_id),
+ "token_id": 1,
+ "is_guest": False,
+ }
+ )
def get_user_by_req(request, allow_guest=False, rights="access"):
- return create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
+ return succeed(
+ create_requester(
+ UserID.from_string(self.helper.auth_user_id), 1, False, None
+ )
)
self.hs.get_auth().get_user_by_req = get_user_by_req
@@ -209,9 +216,26 @@ class HomeserverTestCase(TestCase):
return_value="1234"
)
+ if self.needs_threadpool:
+ self.reactor.threadpool = ThreadPool()
+ self.addCleanup(self.reactor.threadpool.stop)
+ self.reactor.threadpool.start()
+
if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs)
+ def wait_on_thread(self, deferred, timeout=10):
+ """
+ Wait until a Deferred is done, where it's waiting on a real thread.
+ """
+ start_time = time.time()
+
+ while not deferred.called:
+ if start_time + timeout < time.time():
+ raise ValueError("Timed out waiting for threadpool")
+ self.reactor.advance(0.01)
+ time.sleep(0.01)
+
def make_homeserver(self, reactor, clock):
"""
Make and return a homeserver.
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 6f8f52537c..7807328e2f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -21,7 +21,11 @@ import mock
from twisted.internet import defer, reactor
from synapse.api.errors import SynapseError
-from synapse.util import logcontext
+from synapse.logging.context import (
+ LoggingContext,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+)
from synapse.util.caches import descriptors
from tests import unittest
@@ -32,7 +36,7 @@ logger = logging.getLogger(__name__)
def run_on_reactor():
d = defer.Deferred()
reactor.callLater(0, d.callback, 0)
- return logcontext.make_deferred_yieldable(d)
+ return make_deferred_yieldable(d)
class CacheTestCase(unittest.TestCase):
@@ -153,7 +157,7 @@ class DescriptorTestCase(unittest.TestCase):
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
- with logcontext.PreserveLoggingContext():
+ with PreserveLoggingContext():
yield complete_lookup
defer.returnValue(1)
@@ -161,10 +165,10 @@ class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def do_lookup():
- with logcontext.LoggingContext() as c1:
+ with LoggingContext() as c1:
c1.name = "c1"
r = yield obj.fn(1)
- self.assertEqual(logcontext.LoggingContext.current_context(), c1)
+ self.assertEqual(LoggingContext.current_context(), c1)
defer.returnValue(r)
def check_result(r):
@@ -174,18 +178,12 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(
- logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel,
- )
+ self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
d1.addCallback(check_result)
# and another
d2 = do_lookup()
- self.assertEqual(
- logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel,
- )
+ self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
d2.addCallback(check_result)
# let the lookup complete
@@ -210,29 +208,25 @@ class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def do_lookup():
- with logcontext.LoggingContext() as c1:
+ with LoggingContext() as c1:
c1.name = "c1"
try:
d = obj.fn(1)
self.assertEqual(
- logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel,
+ LoggingContext.current_context(), LoggingContext.sentinel
)
yield d
self.fail("No exception thrown")
except SynapseError:
pass
- self.assertEqual(logcontext.LoggingContext.current_context(), c1)
+ self.assertEqual(LoggingContext.current_context(), c1)
obj = Cls()
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(
- logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel,
- )
+ self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
return d1
@@ -288,23 +282,20 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
- assert logcontext.LoggingContext.current_context().request == "c1"
+ assert LoggingContext.current_context().request == "c1"
# we want this to behave like an asynchronous function
yield run_on_reactor()
- assert logcontext.LoggingContext.current_context().request == "c1"
+ assert LoggingContext.current_context().request == "c1"
defer.returnValue(self.mock(args1, arg2))
- with logcontext.LoggingContext() as c1:
+ with LoggingContext() as c1:
c1.request = "c1"
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2)
- self.assertEqual(
- logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel,
- )
+ self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
r = yield d1
- self.assertEqual(logcontext.LoggingContext.current_context(), c1)
+ self.assertEqual(LoggingContext.current_context(), c1)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r, {10: "fish", 20: "chips"})
obj.mock.reset_mock()
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py
index bf85d3b8ec..f60918069a 100644
--- a/tests/util/test_async_utils.py
+++ b/tests/util/test_async_utils.py
@@ -16,9 +16,8 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.task import Clock
-from synapse.util import logcontext
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.util.async_helpers import timeout_deferred
-from synapse.util.logcontext import LoggingContext
from tests.unittest import TestCase
@@ -69,14 +68,14 @@ class TimeoutDeferredTest(TestCase):
@defer.inlineCallbacks
def blocking():
non_completing_d = Deferred()
- with logcontext.PreserveLoggingContext():
+ with PreserveLoggingContext():
try:
yield non_completing_d
except CancelledError:
blocking_was_cancelled[0] = True
raise
- with logcontext.LoggingContext("one") as context_one:
+ with LoggingContext("one") as context_one:
# the errbacks should be run in the test logcontext
def errback(res, deferred_name):
self.assertIs(
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index ec7ba9719c..0ec8ef90ce 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -19,7 +19,8 @@ from six.moves import range
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
-from synapse.util import Clock, logcontext
+from synapse.logging.context import LoggingContext
+from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from tests import unittest
@@ -51,13 +52,13 @@ class LinearizerTestCase(unittest.TestCase):
@defer.inlineCallbacks
def func(i, sleep=False):
- with logcontext.LoggingContext("func(%s)" % i) as lc:
+ with LoggingContext("func(%s)" % i) as lc:
with (yield linearizer.queue("")):
- self.assertEqual(logcontext.LoggingContext.current_context(), lc)
+ self.assertEqual(LoggingContext.current_context(), lc)
if sleep:
yield Clock(reactor).sleep(0)
- self.assertEqual(logcontext.LoggingContext.current_context(), lc)
+ self.assertEqual(LoggingContext.current_context(), lc)
func(0, sleep=True)
for i in range(1, 100):
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8adaee3c8d..8b8455c8b7 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -1,8 +1,14 @@
import twisted.python.failure
from twisted.internet import defer, reactor
-from synapse.util import Clock, logcontext
-from synapse.util.logcontext import LoggingContext
+from synapse.logging.context import (
+ LoggingContext,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+ nested_logging_context,
+ run_in_background,
+)
+from synapse.util import Clock
from .. import unittest
@@ -39,24 +45,17 @@ class LoggingContextTestCase(unittest.TestCase):
callback_completed = [False]
- def test():
+ with LoggingContext() as context_one:
context_one.request = "one"
- d = function()
+
+ # fire off function, but don't wait on it.
+ d2 = run_in_background(function)
def cb(res):
- self._check_test_key("one")
callback_completed[0] = True
return res
- d.addCallback(cb)
-
- return d
-
- with LoggingContext() as context_one:
- context_one.request = "one"
-
- # fire off function, but don't wait on it.
- logcontext.run_in_background(test)
+ d2.addCallback(cb)
self._check_test_key("one")
@@ -92,7 +91,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_run_in_background_with_non_blocking_fn(self):
@defer.inlineCallbacks
def nonblocking_function():
- with logcontext.PreserveLoggingContext():
+ with PreserveLoggingContext():
yield defer.succeed(None)
return self._test_run_in_background(nonblocking_function)
@@ -101,7 +100,23 @@ class LoggingContextTestCase(unittest.TestCase):
# a function which returns a deferred which looks like it has been
# called, but is actually paused
def testfunc():
- return logcontext.make_deferred_yieldable(_chained_deferred_function())
+ return make_deferred_yieldable(_chained_deferred_function())
+
+ return self._test_run_in_background(testfunc)
+
+ def test_run_in_background_with_coroutine(self):
+ async def testfunc():
+ self._check_test_key("one")
+ d = Clock(reactor).sleep(0)
+ self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ await d
+ self._check_test_key("one")
+
+ return self._test_run_in_background(testfunc)
+
+ def test_run_in_background_with_nonblocking_coroutine(self):
+ async def testfunc():
+ self._check_test_key("one")
return self._test_run_in_background(testfunc)
@@ -119,7 +134,7 @@ class LoggingContextTestCase(unittest.TestCase):
with LoggingContext() as context_one:
context_one.request = "one"
- d1 = logcontext.make_deferred_yieldable(blocking_function())
+ d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context)
@@ -135,7 +150,7 @@ class LoggingContextTestCase(unittest.TestCase):
with LoggingContext() as context_one:
context_one.request = "one"
- d1 = logcontext.make_deferred_yieldable(_chained_deferred_function())
+ d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context)
@@ -152,7 +167,7 @@ class LoggingContextTestCase(unittest.TestCase):
with LoggingContext() as context_one:
context_one.request = "one"
- d1 = logcontext.make_deferred_yieldable("bum")
+ d1 = make_deferred_yieldable("bum")
self._check_test_key("one")
r = yield d1
@@ -161,7 +176,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_nested_logging_context(self):
with LoggingContext(request="foo"):
- nested_context = logcontext.nested_logging_context(suffix="bar")
+ nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.request, "foo-bar")
diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py
index 297aebbfbe..0fb60caacb 100644
--- a/tests/util/test_logformatter.py
+++ b/tests/util/test_logformatter.py
@@ -14,7 +14,7 @@
# limitations under the License.
import sys
-from synapse.util.logformatter import LogFormatter
+from synapse.logging.formatter import LogFormatter
from tests import unittest
diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
new file mode 100644
index 0000000000..4d1aee91d5
--- /dev/null
+++ b/tests/util/test_ratelimitutils.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.config.homeserver import HomeServerConfig
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests.server import get_clock
+from tests.unittest import TestCase
+from tests.utils import default_config
+
+
+class FederationRateLimiterTestCase(TestCase):
+ def test_ratelimit(self):
+ """A simple test with the default values"""
+ reactor, clock = get_clock()
+ rc_config = build_rc_config()
+ ratelimiter = FederationRateLimiter(clock, rc_config)
+
+ with ratelimiter.ratelimit("testhost") as d1:
+ # shouldn't block
+ self.successResultOf(d1)
+
+ def test_concurrent_limit(self):
+ """Test what happens when we hit the concurrent limit"""
+ reactor, clock = get_clock()
+ rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
+ ratelimiter = FederationRateLimiter(clock, rc_config)
+
+ with ratelimiter.ratelimit("testhost") as d1:
+ # shouldn't block
+ self.successResultOf(d1)
+
+ cm2 = ratelimiter.ratelimit("testhost")
+ d2 = cm2.__enter__()
+ # also shouldn't block
+ self.successResultOf(d2)
+
+ cm3 = ratelimiter.ratelimit("testhost")
+ d3 = cm3.__enter__()
+ # this one should block, though ...
+ self.assertNoResult(d3)
+
+ # ... until we complete an earlier request
+ cm2.__exit__(None, None, None)
+ self.successResultOf(d3)
+
+ def test_sleep_limit(self):
+ """Test what happens when we hit the sleep limit"""
+ reactor, clock = get_clock()
+ rc_config = build_rc_config(
+ {"rc_federation": {"sleep_limit": 2, "sleep_delay": 500}}
+ )
+ ratelimiter = FederationRateLimiter(clock, rc_config)
+
+ with ratelimiter.ratelimit("testhost") as d1:
+ # shouldn't block
+ self.successResultOf(d1)
+
+ with ratelimiter.ratelimit("testhost") as d2:
+ # nor this
+ self.successResultOf(d2)
+
+ with ratelimiter.ratelimit("testhost") as d3:
+ # this one should block, though ...
+ self.assertNoResult(d3)
+ sleep_time = _await_resolution(reactor, d3)
+ self.assertAlmostEqual(sleep_time, 500, places=3)
+
+
+def _await_resolution(reactor, d):
+ """advance the clock until the deferred completes.
+
+ Returns the number of milliseconds it took to complete.
+ """
+ start_time = reactor.seconds()
+ while not d.called:
+ reactor.advance(0.01)
+ return (reactor.seconds() - start_time) * 1000
+
+
+def build_rc_config(settings={}):
+ config_dict = default_config("test")
+ config_dict.update(settings)
+ config = HomeServerConfig()
+ config.parse_config_dict(config_dict, "", "")
+ return config.rc_federation
diff --git a/tests/utils.py b/tests/utils.py
index da43166f3a..8a94ce0b47 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -34,6 +34,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
+from synapse.logging.context import LoggingContext
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
@@ -42,7 +43,6 @@ from synapse.storage.prepare_database import (
_setup_new_database,
prepare_database,
)
-from synapse.util.logcontext import LoggingContext
from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
@@ -152,12 +152,6 @@ def default_config(name, parse=False):
"mau_stats_only": False,
"mau_limits_reserved_threepids": [],
"admin_contact": None,
- "rc_federation": {
- "reject_limit": 10,
- "sleep_limit": 10,
- "sleep_delay": 10,
- "concurrent": 10,
- },
"rc_message": {"per_second": 10000, "burst_count": 10000},
"rc_registration": {"per_second": 10000, "burst_count": 10000},
"rc_login": {
|