diff --git a/tests/__init__.py b/tests/__init__.py
index 9d9ca22829..d3181f9403 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +16,9 @@
from twisted.trial import util
-from tests import utils
+import tests.patch_inline_callbacks
+
+# attempt to do the patch before we load any synapse code
+tests.patch_inline_callbacks.do_patch()
util.DEFAULT_TIMEOUT_DURATION = 10
-utils.setupdb()
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 8299dc72c8..d643bec887 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -63,6 +63,14 @@ class KeyringTestCase(unittest.TestCase):
keys = self.mock_perspective_server.get_verify_keys()
self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
+ def assert_sentinel_context(self):
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ self.fail(
+ "Expected sentinel context but got %s" % (
+ LoggingContext.current_context(),
+ )
+ )
+
def check_context(self, _, expected):
self.assertEquals(
getattr(LoggingContext.current_context(), "request", None), expected
@@ -70,8 +78,6 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_wait_for_previous_lookups(self):
- sentinel_context = LoggingContext.current_context()
-
kr = keyring.Keyring(self.hs)
lookup_1_deferred = defer.Deferred()
@@ -99,8 +105,10 @@ class KeyringTestCase(unittest.TestCase):
["server1"], {"server1": lookup_2_deferred}
)
self.assertFalse(wait_2_deferred.called)
+
# ... so we should have reset the LoggingContext.
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assert_sentinel_context()
+
wait_2_deferred.addBoth(self.check_context, "two")
# let the first lookup complete (in the sentinel context)
@@ -198,8 +206,6 @@ class KeyringTestCase(unittest.TestCase):
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
- sentinel_context = LoggingContext.current_context()
-
with LoggingContext("one") as context_one:
context_one.request = "one"
@@ -213,7 +219,7 @@ class KeyringTestCase(unittest.TestCase):
defer = kr.verify_json_for_server("server9", json1)
self.assertFalse(defer.called)
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assert_sentinel_context()
yield defer
self.assertIs(LoggingContext.current_context(), context_one)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 3e9a190727..90a2a76475 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -150,7 +150,6 @@ class RegistrationTestCase(unittest.TestCase):
self.hs.config.auto_join_rooms = [room_alias_str]
res = yield self.handler.register(localpart='jeff')
rooms = yield self.store.get_rooms_for_user(res[0])
-
directory_handler = self.hs.get_handlers().directory_handler
room_alias = RoomAlias.from_string(room_alias_str)
room_id = yield directory_handler.get_association(room_alias)
@@ -184,3 +183,14 @@ class RegistrationTestCase(unittest.TestCase):
res = yield self.handler.register(localpart='jeff')
rooms = yield self.store.get_rooms_for_user(res[0])
self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_where_no_consent(self):
+ self.hs.config.user_consent_at_registration = True
+ self.hs.config.block_events_without_consent_error = "Error"
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ res = yield self.handler.register(localpart='jeff')
+ yield self.handler.post_consent_actions(res[0])
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
new file mode 100644
index 0000000000..0f613945c8
--- /dev/null
+++ b/tests/patch_inline_callbacks.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import functools
+import sys
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+
+def do_patch():
+ """
+ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
+ """
+
+ from synapse.util.logcontext import LoggingContext
+
+ orig_inline_callbacks = defer.inlineCallbacks
+
+ def new_inline_callbacks(f):
+
+ orig = orig_inline_callbacks(f)
+
+ @functools.wraps(f)
+ def wrapped(*args, **kwargs):
+ start_context = LoggingContext.current_context()
+
+ try:
+ res = orig(*args, **kwargs)
+ except Exception:
+ if LoggingContext.current_context() != start_context:
+ err = "%s changed context from %s to %s on exception" % (
+ f, start_context, LoggingContext.current_context()
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ raise
+
+ if not isinstance(res, Deferred) or res.called:
+ if LoggingContext.current_context() != start_context:
+ err = "%s changed context from %s to %s" % (
+ f, start_context, LoggingContext.current_context()
+ )
+ # print the error to stderr because otherwise all we
+ # see in travis-ci is the 500 error
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return res
+
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ err = (
+ "%s returned incomplete deferred in non-sentinel context "
+ "%s (start was %s)"
+ ) % (
+ f, LoggingContext.current_context(), start_context,
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+
+ def check_ctx(r):
+ if LoggingContext.current_context() != start_context:
+ err = "%s completion of %s changed context from %s to %s" % (
+ "Failure" if isinstance(r, Failure) else "Success",
+ f, start_context, LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return r
+
+ res.addBoth(check_ctx)
+ return res
+
+ return wrapped
+
+ defer.inlineCallbacks = new_inline_callbacks
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index fd131e3454..ad5e9a612f 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -30,6 +30,7 @@ from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.module_loader import load_module
from tests import unittest
@@ -113,7 +114,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
d = Deferred()
d.addCallback(write_to)
self.fetches.append((d, destination, path, args))
- return d
+ return make_deferred_yieldable(d)
client = Mock()
client.get_file = get_file
diff --git a/tests/unittest.py b/tests/unittest.py
index a9ce57da9a..092c930396 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import gc
import hashlib
import hmac
import logging
@@ -31,10 +31,12 @@ from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
-from synapse.util.logcontext import LoggingContextFilter
+from synapse.util.logcontext import LoggingContext, LoggingContextFilter
from tests.server import get_clock, make_request, render, setup_test_homeserver
-from tests.utils import default_config
+from tests.utils import default_config, setupdb
+
+setupdb()
# Set up putting Synapse's logs into Trial's.
rootLogger = logging.getLogger()
@@ -102,8 +104,16 @@ class TestCase(unittest.TestCase):
# traceback when a unit test exits leaving things on the reactor.
twisted.internet.base.DelayedCall.debug = True
- old_level = logging.getLogger().level
+ # if we're not starting in the sentinel logcontext, then to be honest
+ # all future bets are off.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ self.fail(
+ "Test starting with non-sentinel logging context %s" % (
+ LoggingContext.current_context(),
+ )
+ )
+ old_level = logging.getLogger().level
if old_level != level:
@around(self)
@@ -115,6 +125,16 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level)
return orig()
+ @around(self)
+ def tearDown(orig):
+ ret = orig()
+ # force a GC to workaround problems with deferreds leaking logcontexts when
+ # they are GCed (see the logcontext docs)
+ gc.collect()
+ LoggingContext.set_current_context(LoggingContext.sentinel)
+
+ return ret
+
def assertObjectHasAttributes(self, attrs, obj):
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEquals."""
|