summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py101
1 files changed, 62 insertions, 39 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index a59291cc60..fac254ff10 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import gc
 import hashlib
 import hmac
 import logging
@@ -31,36 +31,14 @@ from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
-from synapse.util.logcontext import LoggingContextFilter
+from synapse.util.logcontext import LoggingContext
 
 from tests.server import get_clock, make_request, render, setup_test_homeserver
-from tests.utils import default_config
-
-# Set up putting Synapse's logs into Trial's.
-rootLogger = logging.getLogger()
-
-log_format = (
-    "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
-)
-
-
-class ToTwistedHandler(logging.Handler):
-    tx_log = twisted.logger.Logger()
-
-    def emit(self, record):
-        log_entry = self.format(record)
-        log_level = record.levelname.lower().replace('warning', 'warn')
-        self.tx_log.emit(
-            twisted.logger.LogLevel.levelWithName(log_level),
-            log_entry.replace("{", r"(").replace("}", r")"),
-        )
+from tests.test_utils.logging_setup import setup_logging
+from tests.utils import default_config, setupdb
 
-
-handler = ToTwistedHandler()
-formatter = logging.Formatter(log_format)
-handler.setFormatter(formatter)
-handler.addFilter(LoggingContextFilter(request=""))
-rootLogger.addHandler(handler)
+setupdb()
+setup_logging()
 
 
 def around(target):
@@ -94,7 +72,7 @@ class TestCase(unittest.TestCase):
 
         method = getattr(self, methodName)
 
-        level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
+        level = getattr(method, "loglevel", getattr(self, "loglevel", None))
 
         @around(self)
         def setUp(orig):
@@ -102,9 +80,17 @@ class TestCase(unittest.TestCase):
             # traceback when a unit test exits leaving things on the reactor.
             twisted.internet.base.DelayedCall.debug = True
 
-            old_level = logging.getLogger().level
+            # if we're not starting in the sentinel logcontext, then to be honest
+            # all future bets are off.
+            if LoggingContext.current_context() is not LoggingContext.sentinel:
+                self.fail(
+                    "Test starting with non-sentinel logging context %s" % (
+                        LoggingContext.current_context(),
+                    )
+                )
 
-            if old_level != level:
+            old_level = logging.getLogger().level
+            if level is not None and old_level != level:
 
                 @around(self)
                 def tearDown(orig):
@@ -112,9 +98,20 @@ class TestCase(unittest.TestCase):
                     logging.getLogger().setLevel(old_level)
                     return ret
 
-            logging.getLogger().setLevel(level)
+                logging.getLogger().setLevel(level)
+
             return orig()
 
+        @around(self)
+        def tearDown(orig):
+            ret = orig()
+            # force a GC to workaround problems with deferreds leaking logcontexts when
+            # they are GCed (see the logcontext docs)
+            gc.collect()
+            LoggingContext.set_current_context(LoggingContext.sentinel)
+
+            return ret
+
     def assertObjectHasAttributes(self, attrs, obj):
         """Asserts that the given object has each of the attributes given, and
         that the value of each matches according to assertEquals."""
@@ -146,6 +143,13 @@ def DEBUG(target):
     return target
 
 
+def INFO(target):
+    """A decorator to set the .loglevel attribute to logging.INFO.
+    Can apply to either a TestCase or an individual test method."""
+    target.loglevel = logging.INFO
+    return target
+
+
 class HomeserverTestCase(TestCase):
     """
     A base TestCase that reduces boilerplate for HomeServer-using test cases.
@@ -182,11 +186,11 @@ class HomeserverTestCase(TestCase):
         for servlet in self.servlets:
             servlet(self.hs, self.resource)
 
-        if hasattr(self, "user_id"):
-            from tests.rest.client.v1.utils import RestHelper
+        from tests.rest.client.v1.utils import RestHelper
 
-            self.helper = RestHelper(self.hs, self.resource, self.user_id)
+        self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
 
+        if hasattr(self, "user_id"):
             if self.hijack_auth:
 
                 def get_user_by_access_token(token=None, allow_guest=False):
@@ -251,7 +255,13 @@ class HomeserverTestCase(TestCase):
         """
 
     def make_request(
-        self, method, path, content=b"", access_token=None, request=SynapseRequest
+        self,
+        method,
+        path,
+        content=b"",
+        access_token=None,
+        request=SynapseRequest,
+        shorthand=True,
     ):
         """
         Create a SynapseRequest at the path using the method and containing the
@@ -263,6 +273,8 @@ class HomeserverTestCase(TestCase):
             escaped UTF-8 & spaces and such).
             content (bytes or dict): The body of the request. JSON-encoded, if
             a dict.
+            shorthand: Whether to try and be helpful and prefix the given URL
+            with the usual REST API path, if it doesn't contain it.
 
         Returns:
             A synapse.http.site.SynapseRequest.
@@ -270,7 +282,9 @@ class HomeserverTestCase(TestCase):
         if isinstance(content, dict):
             content = json.dumps(content).encode('utf8')
 
-        return make_request(method, path, content, access_token, request)
+        return make_request(
+            self.reactor, method, path, content, access_token, request, shorthand
+        )
 
     def render(self, request):
         """
@@ -296,7 +310,15 @@ class HomeserverTestCase(TestCase):
         """
         kwargs = dict(kwargs)
         kwargs.update(self._hs_args)
-        return setup_test_homeserver(self.addCleanup, *args, **kwargs)
+        hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
+        stor = hs.get_datastore()
+
+        # Run the database background updates.
+        if hasattr(stor, "do_next_background_update"):
+            while not self.get_success(stor.has_completed_background_updates()):
+                self.get_success(stor.do_next_background_update(1))
+
+        return hs
 
     def pump(self, by=0.0):
         """
@@ -336,6 +358,7 @@ class HomeserverTestCase(TestCase):
             nonce_str += b"\x00admin"
         else:
             nonce_str += b"\x00notadmin"
+
         want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str)
         want_mac = want_mac.hexdigest()
 
@@ -373,5 +396,5 @@ class HomeserverTestCase(TestCase):
         self.render(request)
         self.assertEqual(channel.code, 200)
 
-        access_token = channel.json_body["access_token"].encode('ascii')
+        access_token = channel.json_body["access_token"]
         return access_token