summary refs log tree commit diff
path: root/tests/util/test_linearizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/test_linearizer.py')
-rw-r--r--tests/util/test_linearizer.py99
1 files changed, 88 insertions, 11 deletions
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 4865eb4bc6..61a55b461b 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,17 +13,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.util import async, logcontext
-from tests import unittest
-
-from twisted.internet import defer
 
-from synapse.util.async import Linearizer
 from six.moves import range
 
+from twisted.internet import defer, reactor
+from twisted.internet.defer import CancelledError
 
-class LinearizerTestCase(unittest.TestCase):
+from synapse.util import Clock, logcontext
+from synapse.util.async_helpers import Linearizer
 
+from tests import unittest
+
+
+class LinearizerTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_linearizer(self):
         linearizer = Linearizer()
@@ -50,16 +53,90 @@ class LinearizerTestCase(unittest.TestCase):
         def func(i, sleep=False):
             with logcontext.LoggingContext("func(%s)" % i) as lc:
                 with (yield linearizer.queue("")):
-                    self.assertEqual(
-                        logcontext.LoggingContext.current_context(), lc)
+                    self.assertEqual(logcontext.LoggingContext.current_context(), lc)
                     if sleep:
-                        yield async.sleep(0)
+                        yield Clock(reactor).sleep(0)
 
-                self.assertEqual(
-                    logcontext.LoggingContext.current_context(), lc)
+                self.assertEqual(logcontext.LoggingContext.current_context(), lc)
 
         func(0, sleep=True)
         for i in range(1, 100):
             func(i)
 
         return func(1000)
+
+    @defer.inlineCallbacks
+    def test_multiple_entries(self):
+        limiter = Linearizer(max_count=3)
+
+        key = object()
+
+        d1 = limiter.queue(key)
+        cm1 = yield d1
+
+        d2 = limiter.queue(key)
+        cm2 = yield d2
+
+        d3 = limiter.queue(key)
+        cm3 = yield d3
+
+        d4 = limiter.queue(key)
+        self.assertFalse(d4.called)
+
+        d5 = limiter.queue(key)
+        self.assertFalse(d5.called)
+
+        with cm1:
+            self.assertFalse(d4.called)
+            self.assertFalse(d5.called)
+
+        cm4 = yield d4
+        self.assertFalse(d5.called)
+
+        with cm3:
+            self.assertFalse(d5.called)
+
+        cm5 = yield d5
+
+        with cm2:
+            pass
+
+        with cm4:
+            pass
+
+        with cm5:
+            pass
+
+        d6 = limiter.queue(key)
+        with (yield d6):
+            pass
+
+    @defer.inlineCallbacks
+    def test_cancellation(self):
+        linearizer = Linearizer()
+
+        key = object()
+
+        d1 = linearizer.queue(key)
+        cm1 = yield d1
+
+        d2 = linearizer.queue(key)
+        self.assertFalse(d2.called)
+
+        d3 = linearizer.queue(key)
+        self.assertFalse(d3.called)
+
+        d2.cancel()
+
+        with cm1:
+            pass
+
+        self.assertTrue(d2.called)
+        try:
+            yield d2
+            self.fail("Expected d2 to raise CancelledError")
+        except CancelledError:
+            pass
+
+        with (yield d3):
+            pass