diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py
deleted file mode 100644
index a5a767b1ff..0000000000
--- a/tests/util/test_limiter.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.util.async import Limiter
-
-from tests import unittest
-
-
-class LimiterTestCase(unittest.TestCase):
-
- @defer.inlineCallbacks
- def test_limiter(self):
- limiter = Limiter(3)
-
- key = object()
-
- d1 = limiter.queue(key)
- cm1 = yield d1
-
- d2 = limiter.queue(key)
- cm2 = yield d2
-
- d3 = limiter.queue(key)
- cm3 = yield d3
-
- d4 = limiter.queue(key)
- self.assertFalse(d4.called)
-
- d5 = limiter.queue(key)
- self.assertFalse(d5.called)
-
- with cm1:
- self.assertFalse(d4.called)
- self.assertFalse(d5.called)
-
- self.assertTrue(d4.called)
- self.assertFalse(d5.called)
-
- with cm3:
- self.assertFalse(d5.called)
-
- self.assertTrue(d5.called)
-
- with cm2:
- pass
-
- with (yield d4):
- pass
-
- with (yield d5):
- pass
-
- d6 = limiter.queue(key)
- with (yield d6):
- pass
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index c95907b32c..4729bd5a0a 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,7 @@
from six.moves import range
from twisted.internet import defer, reactor
+from twisted.internet.defer import CancelledError
from synapse.util import Clock, logcontext
from synapse.util.async import Linearizer
@@ -65,3 +67,79 @@ class LinearizerTestCase(unittest.TestCase):
func(i)
return func(1000)
+
+ @defer.inlineCallbacks
+ def test_multiple_entries(self):
+ limiter = Linearizer(max_count=3)
+
+ key = object()
+
+ d1 = limiter.queue(key)
+ cm1 = yield d1
+
+ d2 = limiter.queue(key)
+ cm2 = yield d2
+
+ d3 = limiter.queue(key)
+ cm3 = yield d3
+
+ d4 = limiter.queue(key)
+ self.assertFalse(d4.called)
+
+ d5 = limiter.queue(key)
+ self.assertFalse(d5.called)
+
+ with cm1:
+ self.assertFalse(d4.called)
+ self.assertFalse(d5.called)
+
+ cm4 = yield d4
+ self.assertFalse(d5.called)
+
+ with cm3:
+ self.assertFalse(d5.called)
+
+ cm5 = yield d5
+
+ with cm2:
+ pass
+
+ with cm4:
+ pass
+
+ with cm5:
+ pass
+
+ d6 = limiter.queue(key)
+ with (yield d6):
+ pass
+
+ @defer.inlineCallbacks
+ def test_cancellation(self):
+ linearizer = Linearizer()
+
+ key = object()
+
+ d1 = linearizer.queue(key)
+ cm1 = yield d1
+
+ d2 = linearizer.queue(key)
+ self.assertFalse(d2.called)
+
+ d3 = linearizer.queue(key)
+ self.assertFalse(d3.called)
+
+ d2.cancel()
+
+ with cm1:
+ pass
+
+ self.assertTrue(d2.called)
+ try:
+ yield d2
+ self.fail("Expected d2 to raise CancelledError")
+ except CancelledError:
+ pass
+
+ with (yield d3):
+ pass
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index e3897c0d19..65b0f2e6fb 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -141,8 +141,8 @@ class StreamChangeCacheTests(unittest.TestCase):
)
# Query all the entries mid-way through the stream, but include one
- # that doesn't exist in it. We should get back the one that doesn't
- # exist, too.
+ # that doesn't exist in it. We shouldn't get back the one that doesn't
+ # exist.
self.assertEqual(
cache.get_entities_changed(
[
@@ -153,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase):
],
stream_pos=2,
),
- set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]),
+ set(["bar@baz.net", "user@elsewhere.org"]),
)
# Query all the entries, but before the first known point. We will get
@@ -178,6 +178,22 @@ class StreamChangeCacheTests(unittest.TestCase):
),
)
+ # Query a subset of the entries mid-way through the stream. We should
+ # only get back the subset.
+ self.assertEqual(
+ cache.get_entities_changed(
+ [
+ "bar@baz.net",
+ ],
+ stream_pos=2,
+ ),
+ set(
+ [
+ "bar@baz.net",
+ ]
+ ),
+ )
+
def test_max_pos(self):
"""
StreamChangeCache.get_max_pos_of_last_change will return the most
|