diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 3f14ab503f..8176a7dabd 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +14,71 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from functools import partial
import mock
+
+from twisted.internet import defer, reactor
+
from synapse.api.errors import SynapseError
-from synapse.util import async
from synapse.util import logcontext
-from twisted.internet import defer
from synapse.util.caches import descriptors
+
from tests import unittest
logger = logging.getLogger(__name__)
+def run_on_reactor():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, 0)
+ return logcontext.make_deferred_yieldable(d)
+
+
+class CacheTestCase(unittest.TestCase):
+ def test_invalidate_all(self):
+ cache = descriptors.Cache("testcache")
+
+ callback_record = [False, False]
+
+ def record_callback(idx):
+ callback_record[idx] = True
+
+ # add a couple of pending entries
+ d1 = defer.Deferred()
+ cache.set("key1", d1, partial(record_callback, 0))
+
+ d2 = defer.Deferred()
+ cache.set("key2", d2, partial(record_callback, 1))
+
+ # lookup should return the deferreds
+ self.assertIs(cache.get("key1"), d1)
+ self.assertIs(cache.get("key2"), d2)
+
+ # let one of the lookups complete
+ d2.callback("result2")
+ self.assertEqual(cache.get("key2"), "result2")
+
+ # now do the invalidation
+ cache.invalidate_all()
+
+ # lookup should return none
+ self.assertIsNone(cache.get("key1", None))
+ self.assertIsNone(cache.get("key2", None))
+
+ # both callbacks should have been callbacked
+ self.assertTrue(
+ callback_record[0], "Invalidation callback for key1 not called",
+ )
+ self.assertTrue(
+ callback_record[1], "Invalidation callback for key2 not called",
+ )
+
+ # letting the other lookup complete should do nothing
+ d1.callback("result1")
+ self.assertIsNone(cache.get("key1", None))
+
+
class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
@@ -149,7 +203,8 @@ class DescriptorTestCase(unittest.TestCase):
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
- yield async.run_on_reactor()
+ # we want this to behave like an asynchronous function
+ yield run_on_reactor()
raise SynapseError(400, "blah")
return inner_fn()
@@ -159,7 +214,12 @@ class DescriptorTestCase(unittest.TestCase):
with logcontext.LoggingContext() as c1:
c1.name = "c1"
try:
- yield obj.fn(1)
+ d = obj.fn(1)
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ logcontext.LoggingContext.sentinel,
+ )
+ yield d
self.fail("No exception thrown")
except SynapseError:
pass
|