diff options
Diffstat (limited to 'tests/util')
-rw-r--r-- | tests/util/caches/test_descriptors.py | 151 | ||||
-rw-r--r-- | tests/util/test_dict_cache.py | 35 | ||||
-rw-r--r-- | tests/util/test_expiring_cache.py | 5 | ||||
-rw-r--r-- | tests/util/test_file_consumer.py | 20 | ||||
-rw-r--r-- | tests/util/test_limiter.py | 70 | ||||
-rw-r--r-- | tests/util/test_linearizer.py | 99 | ||||
-rw-r--r-- | tests/util/test_logcontext.py | 31 | ||||
-rw-r--r-- | tests/util/test_logformatter.py | 1 | ||||
-rw-r--r-- | tests/util/test_lrucache.py | 6 | ||||
-rw-r--r-- | tests/util/test_rwlock.py | 13 | ||||
-rw-r--r-- | tests/util/test_snapshot_cache.py | 6 | ||||
-rw-r--r-- | tests/util/test_stream_change_cache.py | 16 | ||||
-rw-r--r-- | tests/util/test_treecache.py | 4 | ||||
-rw-r--r-- | tests/util/test_wheel_timer.py | 4 |
14 files changed, 286 insertions, 175 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 2516fe40f4..463a737efa 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,20 +13,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial import logging +from functools import partial import mock + +from twisted.internet import defer, reactor + from synapse.api.errors import SynapseError -from synapse.util import async from synapse.util import logcontext -from twisted.internet import defer from synapse.util.caches import descriptors + from tests import unittest logger = logging.getLogger(__name__) +def run_on_reactor(): + d = defer.Deferred() + reactor.callLater(0, d.callback, 0) + return logcontext.make_deferred_yieldable(d) + + class CacheTestCase(unittest.TestCase): def test_invalidate_all(self): cache = descriptors.Cache("testcache") @@ -59,12 +67,8 @@ class CacheTestCase(unittest.TestCase): self.assertIsNone(cache.get("key2", None)) # both callbacks should have been callbacked - self.assertTrue( - callback_record[0], "Invalidation callback for key1 not called", - ) - self.assertTrue( - callback_record[1], "Invalidation callback for key2 not called", - ) + self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") + self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") # letting the other lookup complete should do nothing d1.callback("result1") @@ -160,8 +164,7 @@ class DescriptorTestCase(unittest.TestCase): with logcontext.LoggingContext() as c1: c1.name = "c1" r = yield obj.fn(1) - self.assertEqual(logcontext.LoggingContext.current_context(), - c1) + self.assertEqual(logcontext.LoggingContext.current_context(), c1) defer.returnValue(r) def check_result(r): @@ -171,14 +174,18 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) d1.addCallback(check_result) # and another d2 = do_lookup() - self.assertEqual(logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) d2.addCallback(check_result) # let the lookup complete @@ -195,7 +202,8 @@ class DescriptorTestCase(unittest.TestCase): def fn(self, arg1): @defer.inlineCallbacks def inner_fn(): - yield async.run_on_reactor() + # we want this to behave like an asynchronous function + yield run_on_reactor() raise SynapseError(400, "blah") return inner_fn() @@ -205,20 +213,26 @@ class DescriptorTestCase(unittest.TestCase): with logcontext.LoggingContext() as c1: c1.name = "c1" try: - yield obj.fn(1) + d = obj.fn(1) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) + yield d self.fail("No exception thrown") except SynapseError: pass - self.assertEqual(logcontext.LoggingContext.current_context(), - c1) + self.assertEqual(logcontext.LoggingContext.current_context(), c1) obj = Cls() # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) return d1 @@ -259,3 +273,98 @@ class DescriptorTestCase(unittest.TestCase): r = yield obj.fn(2, 3) self.assertEqual(r, 'chips') obj.mock.assert_not_called() + + +class CachedListDescriptorTestCase(unittest.TestCase): + @defer.inlineCallbacks + def test_cache(self): + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2): + pass + + @descriptors.cachedList("fn", "args1", inlineCallbacks=True) + def list_fn(self, args1, arg2): + assert logcontext.LoggingContext.current_context().request == "c1" + # we want this to behave like an asynchronous function + yield run_on_reactor() + assert logcontext.LoggingContext.current_context().request == "c1" + defer.returnValue(self.mock(args1, arg2)) + + with logcontext.LoggingContext() as c1: + c1.request = "c1" + obj = Cls() + obj.mock.return_value = {10: 'fish', 20: 'chips'} + d1 = obj.list_fn([10, 20], 2) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) + r = yield d1 + self.assertEqual(logcontext.LoggingContext.current_context(), c1) + obj.mock.assert_called_once_with([10, 20], 2) + self.assertEqual(r, {10: 'fish', 20: 'chips'}) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = {30: 'peas'} + r = yield obj.list_fn([20, 30], 2) + obj.mock.assert_called_once_with([30], 2) + self.assertEqual(r, {20: 'chips', 30: 'peas'}) + obj.mock.reset_mock() + + # all the values should now be cached + r = yield obj.fn(10, 2) + self.assertEqual(r, 'fish') + r = yield obj.fn(20, 2) + self.assertEqual(r, 'chips') + r = yield obj.fn(30, 2) + self.assertEqual(r, 'peas') + r = yield obj.list_fn([10, 20, 30], 2) + obj.mock.assert_not_called() + self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'}) + + @defer.inlineCallbacks + def test_invalidate(self): + """Make sure that invalidation callbacks are called.""" + + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2): + pass + + @descriptors.cachedList("fn", "args1", inlineCallbacks=True) + def list_fn(self, args1, arg2): + # we want this to behave like an asynchronous function + yield run_on_reactor() + defer.returnValue(self.mock(args1, arg2)) + + obj = Cls() + invalidate0 = mock.Mock() + invalidate1 = mock.Mock() + + # cache miss + obj.mock.return_value = {10: 'fish', 20: 'chips'} + r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) + obj.mock.assert_called_once_with([10, 20], 2) + self.assertEqual(r1, {10: 'fish', 20: 'chips'}) + obj.mock.reset_mock() + + # cache hit + r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1) + obj.mock.assert_not_called() + self.assertEqual(r2, {10: 'fish', 20: 'chips'}) + + invalidate0.assert_not_called() + invalidate1.assert_not_called() + + # now if we invalidate the keys, both invalidations should get called + obj.fn.invalidate((10, 2)) + invalidate0.assert_called_once() + invalidate1.assert_called_once() diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index bc92f85fa6..34fdc9a43a 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -14,13 +14,12 @@ # limitations under the License. -from tests import unittest - from synapse.util.caches.dictionary_cache import DictionaryCache +from tests import unittest -class DictCacheTestCase(unittest.TestCase): +class DictCacheTestCase(unittest.TestCase): def setUp(self): self.cache = DictionaryCache("foobar") @@ -32,7 +31,7 @@ class DictCacheTestCase(unittest.TestCase): seq = self.cache.sequence test_value = {"test": "test_simple_cache_hit_full"} - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key) self.assertEqual(test_value, c.value) @@ -41,10 +40,8 @@ class DictCacheTestCase(unittest.TestCase): key = "test_simple_cache_hit_partial" seq = self.cache.sequence - test_value = { - "test": "test_simple_cache_hit_partial" - } - self.cache.update(seq, key, test_value, full=True) + test_value = {"test": "test_simple_cache_hit_partial"} + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test"]) self.assertEqual(test_value, c.value) @@ -53,10 +50,8 @@ class DictCacheTestCase(unittest.TestCase): key = "test_simple_cache_miss_partial" seq = self.cache.sequence - test_value = { - "test": "test_simple_cache_miss_partial" - } - self.cache.update(seq, key, test_value, full=True) + test_value = {"test": "test_simple_cache_miss_partial"} + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test2"]) self.assertEqual({}, c.value) @@ -70,7 +65,7 @@ class DictCacheTestCase(unittest.TestCase): "test2": "test_simple_cache_hit_miss_partial2", "test3": "test_simple_cache_hit_miss_partial3", } - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test2"]) self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) @@ -79,16 +74,12 @@ class DictCacheTestCase(unittest.TestCase): key = "test_simple_cache_hit_miss_partial" seq = self.cache.sequence - test_value_1 = { - "test": "test_simple_cache_hit_miss_partial", - } - self.cache.update(seq, key, test_value_1, full=False) + test_value_1 = {"test": "test_simple_cache_hit_miss_partial"} + self.cache.update(seq, key, test_value_1, fetched_keys=set("test")) seq = self.cache.sequence - test_value_2 = { - "test2": "test_simple_cache_hit_miss_partial2", - } - self.cache.update(seq, key, test_value_2, full=False) + test_value_2 = {"test2": "test_simple_cache_hit_miss_partial2"} + self.cache.update(seq, key, test_value_2, fetched_keys=set("test2")) c = self.cache.get(key) self.assertEqual( @@ -96,5 +87,5 @@ class DictCacheTestCase(unittest.TestCase): "test": "test_simple_cache_hit_miss_partial", "test2": "test_simple_cache_hit_miss_partial2", }, - c.value + c.value, ) diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 31d24adb8b..5cbada4eda 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -14,15 +14,14 @@ # limitations under the License. -from .. import unittest - from synapse.util.caches.expiringcache import ExpiringCache from tests.utils import MockClock +from .. import unittest -class ExpiringCacheTestCase(unittest.TestCase): +class ExpiringCacheTestCase(unittest.TestCase): def test_get_set(self): clock = MockClock() cache = ExpiringCache("test", clock, max_len=1) diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index d6e1082779..e90e08d1c0 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -14,23 +14,23 @@ # limitations under the License. -from twisted.internet import defer, reactor +import threading + from mock import NonCallableMock +from six import StringIO + +from twisted.internet import defer, reactor from synapse.util.file_consumer import BackgroundFileConsumer from tests import unittest -from six import StringIO - -import threading class FileConsumerTests(unittest.TestCase): - @defer.inlineCallbacks def test_pull_consumer(self): string_file = StringIO() - consumer = BackgroundFileConsumer(string_file) + consumer = BackgroundFileConsumer(string_file, reactor=reactor) try: producer = DummyPullProducer() @@ -54,7 +54,7 @@ class FileConsumerTests(unittest.TestCase): @defer.inlineCallbacks def test_push_consumer(self): string_file = BlockingStringWrite() - consumer = BackgroundFileConsumer(string_file) + consumer = BackgroundFileConsumer(string_file, reactor=reactor) try: producer = NonCallableMock(spec_set=[]) @@ -80,13 +80,15 @@ class FileConsumerTests(unittest.TestCase): @defer.inlineCallbacks def test_push_producer_feedback(self): string_file = BlockingStringWrite() - consumer = BackgroundFileConsumer(string_file) + consumer = BackgroundFileConsumer(string_file, reactor=reactor) try: producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) resume_deferred = defer.Deferred() - producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None) + producer.resumeProducing.side_effect = lambda: resume_deferred.callback( + None + ) consumer.registerProducer(producer, True) diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py deleted file mode 100644 index 9c795d9fdb..0000000000 --- a/tests/util/test_limiter.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from tests import unittest - -from twisted.internet import defer - -from synapse.util.async import Limiter - - -class LimiterTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def test_limiter(self): - limiter = Limiter(3) - - key = object() - - d1 = limiter.queue(key) - cm1 = yield d1 - - d2 = limiter.queue(key) - cm2 = yield d2 - - d3 = limiter.queue(key) - cm3 = yield d3 - - d4 = limiter.queue(key) - self.assertFalse(d4.called) - - d5 = limiter.queue(key) - self.assertFalse(d5.called) - - with cm1: - self.assertFalse(d4.called) - self.assertFalse(d5.called) - - self.assertTrue(d4.called) - self.assertFalse(d5.called) - - with cm3: - self.assertFalse(d5.called) - - self.assertTrue(d5.called) - - with cm2: - pass - - with (yield d4): - pass - - with (yield d5): - pass - - d6 = limiter.queue(key) - with (yield d6): - pass diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 4865eb4bc6..61a55b461b 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,17 +13,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util import async, logcontext -from tests import unittest - -from twisted.internet import defer -from synapse.util.async import Linearizer from six.moves import range +from twisted.internet import defer, reactor +from twisted.internet.defer import CancelledError -class LinearizerTestCase(unittest.TestCase): +from synapse.util import Clock, logcontext +from synapse.util.async_helpers import Linearizer +from tests import unittest + + +class LinearizerTestCase(unittest.TestCase): @defer.inlineCallbacks def test_linearizer(self): linearizer = Linearizer() @@ -50,16 +53,90 @@ class LinearizerTestCase(unittest.TestCase): def func(i, sleep=False): with logcontext.LoggingContext("func(%s)" % i) as lc: with (yield linearizer.queue("")): - self.assertEqual( - logcontext.LoggingContext.current_context(), lc) + self.assertEqual(logcontext.LoggingContext.current_context(), lc) if sleep: - yield async.sleep(0) + yield Clock(reactor).sleep(0) - self.assertEqual( - logcontext.LoggingContext.current_context(), lc) + self.assertEqual(logcontext.LoggingContext.current_context(), lc) func(0, sleep=True) for i in range(1, 100): func(i) return func(1000) + + @defer.inlineCallbacks + def test_multiple_entries(self): + limiter = Linearizer(max_count=3) + + key = object() + + d1 = limiter.queue(key) + cm1 = yield d1 + + d2 = limiter.queue(key) + cm2 = yield d2 + + d3 = limiter.queue(key) + cm3 = yield d3 + + d4 = limiter.queue(key) + self.assertFalse(d4.called) + + d5 = limiter.queue(key) + self.assertFalse(d5.called) + + with cm1: + self.assertFalse(d4.called) + self.assertFalse(d5.called) + + cm4 = yield d4 + self.assertFalse(d5.called) + + with cm3: + self.assertFalse(d5.called) + + cm5 = yield d5 + + with cm2: + pass + + with cm4: + pass + + with cm5: + pass + + d6 = limiter.queue(key) + with (yield d6): + pass + + @defer.inlineCallbacks + def test_cancellation(self): + linearizer = Linearizer() + + key = object() + + d1 = linearizer.queue(key) + cm1 = yield d1 + + d2 = linearizer.queue(key) + self.assertFalse(d2.called) + + d3 = linearizer.queue(key) + self.assertFalse(d3.called) + + d2.cancel() + + with cm1: + pass + + self.assertTrue(d2.called) + try: + yield d2 + self.fail("Expected d2 to raise CancelledError") + except CancelledError: + pass + + with (yield d3): + pass diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index ad78d884e0..4633db77b3 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -1,19 +1,15 @@ import twisted.python.failure -from twisted.internet import defer -from twisted.internet import reactor -from .. import unittest +from twisted.internet import defer, reactor -from synapse.util.async import sleep -from synapse.util import logcontext +from synapse.util import Clock, logcontext from synapse.util.logcontext import LoggingContext +from .. import unittest + class LoggingContextTestCase(unittest.TestCase): - def _check_test_key(self, value): - self.assertEquals( - LoggingContext.current_context().request, value - ) + self.assertEquals(LoggingContext.current_context().request, value) def test_with_context(self): with LoggingContext() as context_one: @@ -22,18 +18,20 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_sleep(self): + clock = Clock(reactor) + @defer.inlineCallbacks def competing_callback(): with LoggingContext() as competing_context: competing_context.request = "competing" - yield sleep(0) + yield clock.sleep(0) self._check_test_key("competing") reactor.callLater(0, competing_callback) with LoggingContext() as context_one: context_one.request = "one" - yield sleep(0) + yield clock.sleep(0) self._check_test_key("one") def _test_run_in_background(self, function): @@ -49,6 +47,7 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") callback_completed[0] = True return res + d.addCallback(cb) return d @@ -73,8 +72,7 @@ class LoggingContextTestCase(unittest.TestCase): # make sure that the context was reset before it got thrown back # into the reactor try: - self.assertIs(LoggingContext.current_context(), - sentinel_context) + self.assertIs(LoggingContext.current_context(), sentinel_context) d2.callback(None) except BaseException: d2.errback(twisted.python.failure.Failure()) @@ -87,7 +85,7 @@ class LoggingContextTestCase(unittest.TestCase): def test_run_in_background_with_blocking_fn(self): @defer.inlineCallbacks def blocking_function(): - yield sleep(0) + yield Clock(reactor).sleep(0) return self._test_run_in_background(blocking_function) @@ -103,9 +101,7 @@ class LoggingContextTestCase(unittest.TestCase): # a function which returns a deferred which looks like it has been # called, but is actually paused def testfunc(): - return logcontext.make_deferred_yieldable( - _chained_deferred_function() - ) + return logcontext.make_deferred_yieldable(_chained_deferred_function()) return self._test_run_in_background(testfunc) @@ -174,5 +170,6 @@ def _chained_deferred_function(): d2 = defer.Deferred() reactor.callLater(0, d2.callback, res) return d2 + d.addCallback(cb) return d diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py index 1a1a8412f2..297aebbfbe 100644 --- a/tests/util/test_logformatter.py +++ b/tests/util/test_logformatter.py @@ -15,6 +15,7 @@ import sys from synapse.util.logformatter import LogFormatter + from tests import unittest diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index dfb78cb8bd..786947375d 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -14,16 +14,15 @@ # limitations under the License. -from .. import unittest +from mock import Mock from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache -from mock import Mock +from .. import unittest class LruCacheTestCase(unittest.TestCase): - def test_get_set(self): cache = LruCache(1) cache["key"] = "value" @@ -235,7 +234,6 @@ class LruCacheCallbacksTestCase(unittest.TestCase): class LruCacheSizedTestCase(unittest.TestCase): - def test_evict(self): cache = LruCache(5, size_callback=len) cache["key1"] = [0] diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 1d745ae1a7..bd32e2cee7 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -14,13 +14,12 @@ # limitations under the License. -from tests import unittest +from synapse.util.async_helpers import ReadWriteLock -from synapse.util.async import ReadWriteLock +from tests import unittest class ReadWriteLockTestCase(unittest.TestCase): - def _assert_called_before_not_after(self, lst, first_false): for i, d in enumerate(lst[:first_false]): self.assertTrue(d.called, msg="%d was unexpectedly false" % i) @@ -36,12 +35,12 @@ class ReadWriteLockTestCase(unittest.TestCase): key = object() ds = [ - rwlock.read(key), # 0 - rwlock.read(key), # 1 + rwlock.read(key), # 0 + rwlock.read(key), # 1 rwlock.write(key), # 2 rwlock.write(key), # 3 - rwlock.read(key), # 4 - rwlock.read(key), # 5 + rwlock.read(key), # 4 + rwlock.read(key), # 5 rwlock.write(key), # 6 ] diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py index d3a8630c2f..1a44f72425 100644 --- a/tests/util/test_snapshot_cache.py +++ b/tests/util/test_snapshot_cache.py @@ -14,14 +14,14 @@ # limitations under the License. -from .. import unittest +from twisted.internet.defer import Deferred from synapse.util.caches.snapshot_cache import SnapshotCache -from twisted.internet.defer import Deferred +from .. import unittest -class SnapshotCacheTestCase(unittest.TestCase): +class SnapshotCacheTestCase(unittest.TestCase): def setUp(self): self.cache = SnapshotCache() self.cache.DURATION_MS = 1 diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 67ece166c7..f2be63706b 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -1,8 +1,9 @@ -from tests import unittest from mock import patch from synapse.util.caches.stream_change_cache import StreamChangeCache +from tests import unittest + class StreamChangeCacheTests(unittest.TestCase): """ @@ -140,8 +141,8 @@ class StreamChangeCacheTests(unittest.TestCase): ) # Query all the entries mid-way through the stream, but include one - # that doesn't exist in it. We should get back the one that doesn't - # exist, too. + # that doesn't exist in it. We shouldn't get back the one that doesn't + # exist. self.assertEqual( cache.get_entities_changed( [ @@ -152,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase): ], stream_pos=2, ), - set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]), + set(["bar@baz.net", "user@elsewhere.org"]), ) # Query all the entries, but before the first known point. We will get @@ -177,6 +178,13 @@ class StreamChangeCacheTests(unittest.TestCase): ), ) + # Query a subset of the entries mid-way through the stream. We should + # only get back the subset. + self.assertEqual( + cache.get_entities_changed(["bar@baz.net"], stream_pos=2), + set(["bar@baz.net"]), + ) + def test_max_pos(self): """ StreamChangeCache.get_max_pos_of_last_change will return the most diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 7ab578a185..a5f2261208 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -14,10 +14,10 @@ # limitations under the License. -from .. import unittest - from synapse.util.caches.treecache import TreeCache +from .. import unittest + class TreeCacheTestCase(unittest.TestCase): def test_get_set_onelevel(self): diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py index fdb24a48b0..03201a4d9b 100644 --- a/tests/util/test_wheel_timer.py +++ b/tests/util/test_wheel_timer.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .. import unittest - from synapse.util.wheel_timer import WheelTimer +from .. import unittest + class WheelTimerTestCase(unittest.TestCase): def test_single_insert_fetch(self): |