diff --git a/tests/unittest.py b/tests/unittest.py
index a59291cc60..092c930396 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import gc
import hashlib
import hmac
import logging
@@ -31,10 +31,12 @@ from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
-from synapse.util.logcontext import LoggingContextFilter
+from synapse.util.logcontext import LoggingContext, LoggingContextFilter
from tests.server import get_clock, make_request, render, setup_test_homeserver
-from tests.utils import default_config
+from tests.utils import default_config, setupdb
+
+setupdb()
# Set up putting Synapse's logs into Trial's.
rootLogger = logging.getLogger()
@@ -102,8 +104,16 @@ class TestCase(unittest.TestCase):
# traceback when a unit test exits leaving things on the reactor.
twisted.internet.base.DelayedCall.debug = True
- old_level = logging.getLogger().level
+ # if we're not starting in the sentinel logcontext, then to be honest
+ # all future bets are off.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ self.fail(
+ "Test starting with non-sentinel logging context %s" % (
+ LoggingContext.current_context(),
+ )
+ )
+ old_level = logging.getLogger().level
if old_level != level:
@around(self)
@@ -115,6 +125,16 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level)
return orig()
+ @around(self)
+ def tearDown(orig):
+ ret = orig()
+ # force a GC to workaround problems with deferreds leaking logcontexts when
+ # they are GCed (see the logcontext docs)
+ gc.collect()
+ LoggingContext.set_current_context(LoggingContext.sentinel)
+
+ return ret
+
def assertObjectHasAttributes(self, attrs, obj):
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEquals."""
@@ -146,6 +166,13 @@ def DEBUG(target):
return target
+def INFO(target):
+ """A decorator to set the .loglevel attribute to logging.INFO.
+ Can apply to either a TestCase or an individual test method."""
+ target.loglevel = logging.INFO
+ return target
+
+
class HomeserverTestCase(TestCase):
"""
A base TestCase that reduces boilerplate for HomeServer-using test cases.
@@ -182,11 +209,11 @@ class HomeserverTestCase(TestCase):
for servlet in self.servlets:
servlet(self.hs, self.resource)
- if hasattr(self, "user_id"):
- from tests.rest.client.v1.utils import RestHelper
+ from tests.rest.client.v1.utils import RestHelper
- self.helper = RestHelper(self.hs, self.resource, self.user_id)
+ self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
+ if hasattr(self, "user_id"):
if self.hijack_auth:
def get_user_by_access_token(token=None, allow_guest=False):
@@ -251,7 +278,13 @@ class HomeserverTestCase(TestCase):
"""
def make_request(
- self, method, path, content=b"", access_token=None, request=SynapseRequest
+ self,
+ method,
+ path,
+ content=b"",
+ access_token=None,
+ request=SynapseRequest,
+ shorthand=True,
):
"""
Create a SynapseRequest at the path using the method and containing the
@@ -263,6 +296,8 @@ class HomeserverTestCase(TestCase):
escaped UTF-8 & spaces and such).
content (bytes or dict): The body of the request. JSON-encoded, if
a dict.
+ shorthand: Whether to try and be helpful and prefix the given URL
+ with the usual REST API path, if it doesn't contain it.
Returns:
A synapse.http.site.SynapseRequest.
@@ -270,7 +305,9 @@ class HomeserverTestCase(TestCase):
if isinstance(content, dict):
content = json.dumps(content).encode('utf8')
- return make_request(method, path, content, access_token, request)
+ return make_request(
+ self.reactor, method, path, content, access_token, request, shorthand
+ )
def render(self, request):
"""
@@ -373,5 +410,5 @@ class HomeserverTestCase(TestCase):
self.render(request)
self.assertEqual(channel.code, 200)
- access_token = channel.json_body["access_token"].encode('ascii')
+ access_token = channel.json_body["access_token"]
return access_token
|