diff options
-rw-r--r-- | synapse/util/async.py | 45 | ||||
-rw-r--r-- | tests/util/test_limiter.py | 70 |
2 files changed, 115 insertions, 0 deletions
diff --git a/synapse/util/async.py b/synapse/util/async.py index 347fb1e380..2a680fff5e 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -197,6 +197,51 @@ class Linearizer(object): defer.returnValue(_ctx_manager()) +class Limiter(object): + """Limits concurrent access to resources based on a key. Useful to ensure + only a few thing happen at a time on a given resource. + + Example: + + with (yield limiter.queue("test_key")): + # do some work. + + """ + def __init__(self, max_count): + """ + Args: + max_count(int): The maximum number of concurrent access + """ + self.max_count = max_count + self.key_to_defer = {} + + @defer.inlineCallbacks + def queue(self, key): + entry = self.key_to_defer.setdefault(key, [0, []]) + + if entry[0] >= self.max_count: + new_defer = defer.Deferred() + entry[1].append(new_defer) + with PreserveLoggingContext(): + yield new_defer + + entry[0] += 1 + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + entry[0] -= 1 + try: + entry[1].pop(0).callback(None) + except IndexError: + if entry[0] == 0: + self.key_to_defer.pop(key, None) + + defer.returnValue(_ctx_manager()) + + class ReadWriteLock(object): """A deferred style read write lock. diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py new file mode 100644 index 0000000000..9c795d9fdb --- /dev/null +++ b/tests/util/test_limiter.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest + +from twisted.internet import defer + +from synapse.util.async import Limiter + + +class LimiterTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def test_limiter(self): + limiter = Limiter(3) + + key = object() + + d1 = limiter.queue(key) + cm1 = yield d1 + + d2 = limiter.queue(key) + cm2 = yield d2 + + d3 = limiter.queue(key) + cm3 = yield d3 + + d4 = limiter.queue(key) + self.assertFalse(d4.called) + + d5 = limiter.queue(key) + self.assertFalse(d5.called) + + with cm1: + self.assertFalse(d4.called) + self.assertFalse(d5.called) + + self.assertTrue(d4.called) + self.assertFalse(d5.called) + + with cm3: + self.assertFalse(d5.called) + + self.assertTrue(d5.called) + + with cm2: + pass + + with (yield d4): + pass + + with (yield d5): + pass + + d6 = limiter.queue(key) + with (yield d6): + pass |