diff options
Diffstat (limited to 'tests/unittest.py')
-rw-r--r-- | tests/unittest.py | 101 |
1 files changed, 62 insertions, 39 deletions
diff --git a/tests/unittest.py b/tests/unittest.py index a59291cc60..fac254ff10 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import gc import hashlib import hmac import logging @@ -31,36 +31,14 @@ from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest from synapse.server import HomeServer from synapse.types import UserID, create_requester -from synapse.util.logcontext import LoggingContextFilter +from synapse.util.logcontext import LoggingContext from tests.server import get_clock, make_request, render, setup_test_homeserver -from tests.utils import default_config - -# Set up putting Synapse's logs into Trial's. -rootLogger = logging.getLogger() - -log_format = ( - "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s" -) - - -class ToTwistedHandler(logging.Handler): - tx_log = twisted.logger.Logger() - - def emit(self, record): - log_entry = self.format(record) - log_level = record.levelname.lower().replace('warning', 'warn') - self.tx_log.emit( - twisted.logger.LogLevel.levelWithName(log_level), - log_entry.replace("{", r"(").replace("}", r")"), - ) +from tests.test_utils.logging_setup import setup_logging +from tests.utils import default_config, setupdb - -handler = ToTwistedHandler() -formatter = logging.Formatter(log_format) -handler.setFormatter(formatter) -handler.addFilter(LoggingContextFilter(request="")) -rootLogger.addHandler(handler) +setupdb() +setup_logging() def around(target): @@ -94,7 +72,7 @@ class TestCase(unittest.TestCase): method = getattr(self, methodName) - level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR)) + level = getattr(method, "loglevel", getattr(self, "loglevel", None)) @around(self) def setUp(orig): @@ -102,9 +80,17 @@ class TestCase(unittest.TestCase): # traceback when a unit test exits leaving things on the reactor. twisted.internet.base.DelayedCall.debug = True - old_level = logging.getLogger().level + # if we're not starting in the sentinel logcontext, then to be honest + # all future bets are off. + if LoggingContext.current_context() is not LoggingContext.sentinel: + self.fail( + "Test starting with non-sentinel logging context %s" % ( + LoggingContext.current_context(), + ) + ) - if old_level != level: + old_level = logging.getLogger().level + if level is not None and old_level != level: @around(self) def tearDown(orig): @@ -112,9 +98,20 @@ class TestCase(unittest.TestCase): logging.getLogger().setLevel(old_level) return ret - logging.getLogger().setLevel(level) + logging.getLogger().setLevel(level) + return orig() + @around(self) + def tearDown(orig): + ret = orig() + # force a GC to workaround problems with deferreds leaking logcontexts when + # they are GCed (see the logcontext docs) + gc.collect() + LoggingContext.set_current_context(LoggingContext.sentinel) + + return ret + def assertObjectHasAttributes(self, attrs, obj): """Asserts that the given object has each of the attributes given, and that the value of each matches according to assertEquals.""" @@ -146,6 +143,13 @@ def DEBUG(target): return target +def INFO(target): + """A decorator to set the .loglevel attribute to logging.INFO. + Can apply to either a TestCase or an individual test method.""" + target.loglevel = logging.INFO + return target + + class HomeserverTestCase(TestCase): """ A base TestCase that reduces boilerplate for HomeServer-using test cases. @@ -182,11 +186,11 @@ class HomeserverTestCase(TestCase): for servlet in self.servlets: servlet(self.hs, self.resource) - if hasattr(self, "user_id"): - from tests.rest.client.v1.utils import RestHelper + from tests.rest.client.v1.utils import RestHelper - self.helper = RestHelper(self.hs, self.resource, self.user_id) + self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) + if hasattr(self, "user_id"): if self.hijack_auth: def get_user_by_access_token(token=None, allow_guest=False): @@ -251,7 +255,13 @@ class HomeserverTestCase(TestCase): """ def make_request( - self, method, path, content=b"", access_token=None, request=SynapseRequest + self, + method, + path, + content=b"", + access_token=None, + request=SynapseRequest, + shorthand=True, ): """ Create a SynapseRequest at the path using the method and containing the @@ -263,6 +273,8 @@ class HomeserverTestCase(TestCase): escaped UTF-8 & spaces and such). content (bytes or dict): The body of the request. JSON-encoded, if a dict. + shorthand: Whether to try and be helpful and prefix the given URL + with the usual REST API path, if it doesn't contain it. Returns: A synapse.http.site.SynapseRequest. @@ -270,7 +282,9 @@ class HomeserverTestCase(TestCase): if isinstance(content, dict): content = json.dumps(content).encode('utf8') - return make_request(method, path, content, access_token, request) + return make_request( + self.reactor, method, path, content, access_token, request, shorthand + ) def render(self, request): """ @@ -296,7 +310,15 @@ class HomeserverTestCase(TestCase): """ kwargs = dict(kwargs) kwargs.update(self._hs_args) - return setup_test_homeserver(self.addCleanup, *args, **kwargs) + hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) + stor = hs.get_datastore() + + # Run the database background updates. + if hasattr(stor, "do_next_background_update"): + while not self.get_success(stor.has_completed_background_updates()): + self.get_success(stor.do_next_background_update(1)) + + return hs def pump(self, by=0.0): """ @@ -336,6 +358,7 @@ class HomeserverTestCase(TestCase): nonce_str += b"\x00admin" else: nonce_str += b"\x00notadmin" + want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str) want_mac = want_mac.hexdigest() @@ -373,5 +396,5 @@ class HomeserverTestCase(TestCase): self.render(request) self.assertEqual(channel.code, 200) - access_token = channel.json_body["access_token"].encode('ascii') + access_token = channel.json_body["access_token"] return access_token |