diff --git a/tests/unittest.py b/tests/unittest.py
index a59291cc60..fac254ff10 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import gc
import hashlib
import hmac
import logging
@@ -31,36 +31,14 @@ from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
-from synapse.util.logcontext import LoggingContextFilter
+from synapse.util.logcontext import LoggingContext
from tests.server import get_clock, make_request, render, setup_test_homeserver
-from tests.utils import default_config
-
-# Set up putting Synapse's logs into Trial's.
-rootLogger = logging.getLogger()
-
-log_format = (
- "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
-)
-
-
-class ToTwistedHandler(logging.Handler):
- tx_log = twisted.logger.Logger()
-
- def emit(self, record):
- log_entry = self.format(record)
- log_level = record.levelname.lower().replace('warning', 'warn')
- self.tx_log.emit(
- twisted.logger.LogLevel.levelWithName(log_level),
- log_entry.replace("{", r"(").replace("}", r")"),
- )
+from tests.test_utils.logging_setup import setup_logging
+from tests.utils import default_config, setupdb
-
-handler = ToTwistedHandler()
-formatter = logging.Formatter(log_format)
-handler.setFormatter(formatter)
-handler.addFilter(LoggingContextFilter(request=""))
-rootLogger.addHandler(handler)
+setupdb()
+setup_logging()
def around(target):
@@ -94,7 +72,7 @@ class TestCase(unittest.TestCase):
method = getattr(self, methodName)
- level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
+ level = getattr(method, "loglevel", getattr(self, "loglevel", None))
@around(self)
def setUp(orig):
@@ -102,9 +80,17 @@ class TestCase(unittest.TestCase):
# traceback when a unit test exits leaving things on the reactor.
twisted.internet.base.DelayedCall.debug = True
- old_level = logging.getLogger().level
+ # if we're not starting in the sentinel logcontext, then to be honest
+ # all future bets are off.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ self.fail(
+ "Test starting with non-sentinel logging context %s" % (
+ LoggingContext.current_context(),
+ )
+ )
- if old_level != level:
+ old_level = logging.getLogger().level
+ if level is not None and old_level != level:
@around(self)
def tearDown(orig):
@@ -112,9 +98,20 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(old_level)
return ret
- logging.getLogger().setLevel(level)
+ logging.getLogger().setLevel(level)
+
return orig()
+ @around(self)
+ def tearDown(orig):
+ ret = orig()
+ # force a GC to workaround problems with deferreds leaking logcontexts when
+ # they are GCed (see the logcontext docs)
+ gc.collect()
+ LoggingContext.set_current_context(LoggingContext.sentinel)
+
+ return ret
+
def assertObjectHasAttributes(self, attrs, obj):
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEquals."""
@@ -146,6 +143,13 @@ def DEBUG(target):
return target
+def INFO(target):
+ """A decorator to set the .loglevel attribute to logging.INFO.
+ Can apply to either a TestCase or an individual test method."""
+ target.loglevel = logging.INFO
+ return target
+
+
class HomeserverTestCase(TestCase):
"""
A base TestCase that reduces boilerplate for HomeServer-using test cases.
@@ -182,11 +186,11 @@ class HomeserverTestCase(TestCase):
for servlet in self.servlets:
servlet(self.hs, self.resource)
- if hasattr(self, "user_id"):
- from tests.rest.client.v1.utils import RestHelper
+ from tests.rest.client.v1.utils import RestHelper
- self.helper = RestHelper(self.hs, self.resource, self.user_id)
+ self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
+ if hasattr(self, "user_id"):
if self.hijack_auth:
def get_user_by_access_token(token=None, allow_guest=False):
@@ -251,7 +255,13 @@ class HomeserverTestCase(TestCase):
"""
def make_request(
- self, method, path, content=b"", access_token=None, request=SynapseRequest
+ self,
+ method,
+ path,
+ content=b"",
+ access_token=None,
+ request=SynapseRequest,
+ shorthand=True,
):
"""
Create a SynapseRequest at the path using the method and containing the
@@ -263,6 +273,8 @@ class HomeserverTestCase(TestCase):
escaped UTF-8 & spaces and such).
content (bytes or dict): The body of the request. JSON-encoded, if
a dict.
+ shorthand: Whether to try and be helpful and prefix the given URL
+ with the usual REST API path, if it doesn't contain it.
Returns:
A synapse.http.site.SynapseRequest.
@@ -270,7 +282,9 @@ class HomeserverTestCase(TestCase):
if isinstance(content, dict):
content = json.dumps(content).encode('utf8')
- return make_request(method, path, content, access_token, request)
+ return make_request(
+ self.reactor, method, path, content, access_token, request, shorthand
+ )
def render(self, request):
"""
@@ -296,7 +310,15 @@ class HomeserverTestCase(TestCase):
"""
kwargs = dict(kwargs)
kwargs.update(self._hs_args)
- return setup_test_homeserver(self.addCleanup, *args, **kwargs)
+ hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
+ stor = hs.get_datastore()
+
+ # Run the database background updates.
+ if hasattr(stor, "do_next_background_update"):
+ while not self.get_success(stor.has_completed_background_updates()):
+ self.get_success(stor.do_next_background_update(1))
+
+ return hs
def pump(self, by=0.0):
"""
@@ -336,6 +358,7 @@ class HomeserverTestCase(TestCase):
nonce_str += b"\x00admin"
else:
nonce_str += b"\x00notadmin"
+
want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str)
want_mac = want_mac.hexdigest()
@@ -373,5 +396,5 @@ class HomeserverTestCase(TestCase):
self.render(request)
self.assertEqual(channel.code, 200)
- access_token = channel.json_body["access_token"].encode('ascii')
+ access_token = channel.json_body["access_token"]
return access_token
|