summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py29
1 files changed, 25 insertions, 4 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index a9ce57da9a..78d2f740f9 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import gc
 import hashlib
 import hmac
 import logging
@@ -31,10 +31,12 @@ from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
-from synapse.util.logcontext import LoggingContextFilter
+from synapse.util.logcontext import LoggingContext, LoggingContextFilter
 
 from tests.server import get_clock, make_request, render, setup_test_homeserver
-from tests.utils import default_config
+from tests.utils import default_config, setupdb
+
+setupdb()
 
 # Set up putting Synapse's logs into Trial's.
 rootLogger = logging.getLogger()
@@ -102,8 +104,16 @@ class TestCase(unittest.TestCase):
             # traceback when a unit test exits leaving things on the reactor.
             twisted.internet.base.DelayedCall.debug = True
 
-            old_level = logging.getLogger().level
+            # if we're not starting in the sentinel logcontext, then to be honest
+            # all future bets are off.
+            if LoggingContext.current_context() is not LoggingContext.sentinel:
+                self.fail(
+                    "Test starting with non-sentinel logging context %s" % (
+                        LoggingContext.current_context(),
+                    )
+                )
 
+            old_level = logging.getLogger().level
             if old_level != level:
 
                 @around(self)
@@ -115,6 +125,16 @@ class TestCase(unittest.TestCase):
             logging.getLogger().setLevel(level)
             return orig()
 
+        @around(self)
+        def tearDown(orig):
+            ret = orig()
+            # force a GC to workaround problems with deferreds leaking logcontexts when
+            # they are GCed (see the logcontext docs)
+            gc.collect()
+            LoggingContext.set_current_context(LoggingContext.sentinel)
+
+            return ret
+
     def assertObjectHasAttributes(self, attrs, obj):
         """Asserts that the given object has each of the attributes given, and
         that the value of each matches according to assertEquals."""
@@ -353,6 +373,7 @@ class HomeserverTestCase(TestCase):
             nonce_str += b"\x00admin"
         else:
             nonce_str += b"\x00notadmin"
+
         want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str)
         want_mac = want_mac.hexdigest()