summary refs log tree commit diff
path: root/tests/util/caches/test_descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/caches/test_descriptors.py')
-rw-r--r--tests/util/caches/test_descriptors.py30
1 files changed, 16 insertions, 14 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 5713870f48..4d2b9e0d64 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -22,8 +22,10 @@ from twisted.internet import defer, reactor
 
 from synapse.api.errors import SynapseError
 from synapse.logging.context import (
+    SENTINEL_CONTEXT,
     LoggingContext,
     PreserveLoggingContext,
+    current_context,
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
@@ -194,7 +196,7 @@ class DescriptorTestCase(unittest.TestCase):
             with LoggingContext() as c1:
                 c1.name = "c1"
                 r = yield obj.fn(1)
-                self.assertEqual(LoggingContext.current_context(), c1)
+                self.assertEqual(current_context(), c1)
             return r
 
         def check_result(r):
@@ -204,12 +206,12 @@ class DescriptorTestCase(unittest.TestCase):
 
         # set off a deferred which will do a cache lookup
         d1 = do_lookup()
-        self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+        self.assertEqual(current_context(), SENTINEL_CONTEXT)
         d1.addCallback(check_result)
 
         # and another
         d2 = do_lookup()
-        self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+        self.assertEqual(current_context(), SENTINEL_CONTEXT)
         d2.addCallback(check_result)
 
         # let the lookup complete
@@ -239,14 +241,14 @@ class DescriptorTestCase(unittest.TestCase):
                 try:
                     d = obj.fn(1)
                     self.assertEqual(
-                        LoggingContext.current_context(), LoggingContext.sentinel
+                        current_context(), SENTINEL_CONTEXT,
                     )
                     yield d
                     self.fail("No exception thrown")
                 except SynapseError:
                     pass
 
-                self.assertEqual(LoggingContext.current_context(), c1)
+                self.assertEqual(current_context(), c1)
 
             # the cache should now be empty
             self.assertEqual(len(obj.fn.cache.cache), 0)
@@ -255,7 +257,7 @@ class DescriptorTestCase(unittest.TestCase):
 
         # set off a deferred which will do a cache lookup
         d1 = do_lookup()
-        self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+        self.assertEqual(current_context(), SENTINEL_CONTEXT)
 
         return d1
 
@@ -310,14 +312,14 @@ class DescriptorTestCase(unittest.TestCase):
 
         obj.mock.return_value = ["spam", "eggs"]
         r = obj.fn(1, 2)
-        self.assertEqual(r, ["spam", "eggs"])
+        self.assertEqual(r.result, ["spam", "eggs"])
         obj.mock.assert_called_once_with(1, 2)
         obj.mock.reset_mock()
 
         # a call with different params should call the mock again
         obj.mock.return_value = ["chips"]
         r = obj.fn(1, 3)
-        self.assertEqual(r, ["chips"])
+        self.assertEqual(r.result, ["chips"])
         obj.mock.assert_called_once_with(1, 3)
         obj.mock.reset_mock()
 
@@ -325,9 +327,9 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(len(obj.fn.cache.cache), 3)
 
         r = obj.fn(1, 2)
-        self.assertEqual(r, ["spam", "eggs"])
+        self.assertEqual(r.result, ["spam", "eggs"])
         r = obj.fn(1, 3)
-        self.assertEqual(r, ["chips"])
+        self.assertEqual(r.result, ["chips"])
         obj.mock.assert_not_called()
 
     def test_cache_iterable_with_sync_exception(self):
@@ -366,10 +368,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
 
             @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
             def list_fn(self, args1, arg2):
-                assert LoggingContext.current_context().request == "c1"
+                assert current_context().request == "c1"
                 # we want this to behave like an asynchronous function
                 yield run_on_reactor()
-                assert LoggingContext.current_context().request == "c1"
+                assert current_context().request == "c1"
                 return self.mock(args1, arg2)
 
         with LoggingContext() as c1:
@@ -377,9 +379,9 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             obj = Cls()
             obj.mock.return_value = {10: "fish", 20: "chips"}
             d1 = obj.list_fn([10, 20], 2)
-            self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+            self.assertEqual(current_context(), SENTINEL_CONTEXT)
             r = yield d1
-            self.assertEqual(LoggingContext.current_context(), c1)
+            self.assertEqual(current_context(), c1)
             obj.mock.assert_called_once_with([10, 20], 2)
             self.assertEqual(r, {10: "fish", 20: "chips"})
             obj.mock.reset_mock()