diff options
Diffstat (limited to 'tests/unittest.py')
-rw-r--r-- | tests/unittest.py | 57 |
1 files changed, 47 insertions, 10 deletions
diff --git a/tests/unittest.py b/tests/unittest.py index a59291cc60..092c930396 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import gc import hashlib import hmac import logging @@ -31,10 +31,12 @@ from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest from synapse.server import HomeServer from synapse.types import UserID, create_requester -from synapse.util.logcontext import LoggingContextFilter +from synapse.util.logcontext import LoggingContext, LoggingContextFilter from tests.server import get_clock, make_request, render, setup_test_homeserver -from tests.utils import default_config +from tests.utils import default_config, setupdb + +setupdb() # Set up putting Synapse's logs into Trial's. rootLogger = logging.getLogger() @@ -102,8 +104,16 @@ class TestCase(unittest.TestCase): # traceback when a unit test exits leaving things on the reactor. twisted.internet.base.DelayedCall.debug = True - old_level = logging.getLogger().level + # if we're not starting in the sentinel logcontext, then to be honest + # all future bets are off. + if LoggingContext.current_context() is not LoggingContext.sentinel: + self.fail( + "Test starting with non-sentinel logging context %s" % ( + LoggingContext.current_context(), + ) + ) + old_level = logging.getLogger().level if old_level != level: @around(self) @@ -115,6 +125,16 @@ class TestCase(unittest.TestCase): logging.getLogger().setLevel(level) return orig() + @around(self) + def tearDown(orig): + ret = orig() + # force a GC to workaround problems with deferreds leaking logcontexts when + # they are GCed (see the logcontext docs) + gc.collect() + LoggingContext.set_current_context(LoggingContext.sentinel) + + return ret + def assertObjectHasAttributes(self, attrs, obj): """Asserts that the given object has each of the attributes given, and that the value of each matches according to assertEquals.""" @@ -146,6 +166,13 @@ def DEBUG(target): return target +def INFO(target): + """A decorator to set the .loglevel attribute to logging.INFO. + Can apply to either a TestCase or an individual test method.""" + target.loglevel = logging.INFO + return target + + class HomeserverTestCase(TestCase): """ A base TestCase that reduces boilerplate for HomeServer-using test cases. @@ -182,11 +209,11 @@ class HomeserverTestCase(TestCase): for servlet in self.servlets: servlet(self.hs, self.resource) - if hasattr(self, "user_id"): - from tests.rest.client.v1.utils import RestHelper + from tests.rest.client.v1.utils import RestHelper - self.helper = RestHelper(self.hs, self.resource, self.user_id) + self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) + if hasattr(self, "user_id"): if self.hijack_auth: def get_user_by_access_token(token=None, allow_guest=False): @@ -251,7 +278,13 @@ class HomeserverTestCase(TestCase): """ def make_request( - self, method, path, content=b"", access_token=None, request=SynapseRequest + self, + method, + path, + content=b"", + access_token=None, + request=SynapseRequest, + shorthand=True, ): """ Create a SynapseRequest at the path using the method and containing the @@ -263,6 +296,8 @@ class HomeserverTestCase(TestCase): escaped UTF-8 & spaces and such). content (bytes or dict): The body of the request. JSON-encoded, if a dict. + shorthand: Whether to try and be helpful and prefix the given URL + with the usual REST API path, if it doesn't contain it. Returns: A synapse.http.site.SynapseRequest. @@ -270,7 +305,9 @@ class HomeserverTestCase(TestCase): if isinstance(content, dict): content = json.dumps(content).encode('utf8') - return make_request(method, path, content, access_token, request) + return make_request( + self.reactor, method, path, content, access_token, request, shorthand + ) def render(self, request): """ @@ -373,5 +410,5 @@ class HomeserverTestCase(TestCase): self.render(request) self.assertEqual(channel.code, 200) - access_token = channel.json_body["access_token"].encode('ascii') + access_token = channel.json_body["access_token"] return access_token |