diff options
-rw-r--r-- | synapse/util/async.py | 82 | ||||
-rw-r--r-- | tests/util/test_rwlock.py | 85 |
2 files changed, 167 insertions, 0 deletions
diff --git a/synapse/util/async.py b/synapse/util/async.py index 40be7fe7e3..c84b23ff46 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -194,3 +194,85 @@ class Linearizer(object): self.key_to_defer.pop(key, None) defer.returnValue(_ctx_manager()) + + +class ReadWriteLock(object): + """A deferred style read write lock. + + Example: + + with (yield read_write_lock.read("test_key")): + # do some work + """ + + # IMPLEMENTATION NOTES + # + # We track the most recent queued reader and writer deferreds (which get + # resolved when they release the lock). + # + # Read: We know its safe to acquire a read lock when the latest writer has + # been resolved. The new reader is appeneded to the list of latest readers. + # + # Write: We know its safe to acquire the write lock when both the latest + # writers and readers have been resolved. The new writer replaces the latest + # writer. + + def __init__(self): + # Latest readers queued + self.key_to_current_readers = {} + + # Latest writer queued + self.key_to_current_writer = {} + + @defer.inlineCallbacks + def read(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + curr_readers.add(new_defer) + + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + yield curr_writer + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + self.key_to_current_readers.get(key, set()).discard(new_defer) + + defer.returnValue(_ctx_manager()) + + @defer.inlineCallbacks + def write(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) + + # We can clear the list of current readers since the new writer waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer + + yield defer.gatherResults(to_wait_on) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + if self.key_to_current_writer[key] == new_defer: + self.key_to_current_writer.pop(key) + + defer.returnValue(_ctx_manager()) diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py new file mode 100644 index 0000000000..1d745ae1a7 --- /dev/null +++ b/tests/util/test_rwlock.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest + +from synapse.util.async import ReadWriteLock + + +class ReadWriteLockTestCase(unittest.TestCase): + + def _assert_called_before_not_after(self, lst, first_false): + for i, d in enumerate(lst[:first_false]): + self.assertTrue(d.called, msg="%d was unexpectedly false" % i) + + for i, d in enumerate(lst[first_false:]): + self.assertFalse( + d.called, msg="%d was unexpectedly true" % (i + first_false) + ) + + def test_rwlock(self): + rwlock = ReadWriteLock() + + key = object() + + ds = [ + rwlock.read(key), # 0 + rwlock.read(key), # 1 + rwlock.write(key), # 2 + rwlock.write(key), # 3 + rwlock.read(key), # 4 + rwlock.read(key), # 5 + rwlock.write(key), # 6 + ] + + self._assert_called_before_not_after(ds, 2) + + with ds[0].result: + self._assert_called_before_not_after(ds, 2) + self._assert_called_before_not_after(ds, 2) + + with ds[1].result: + self._assert_called_before_not_after(ds, 2) + self._assert_called_before_not_after(ds, 3) + + with ds[2].result: + self._assert_called_before_not_after(ds, 3) + self._assert_called_before_not_after(ds, 4) + + with ds[3].result: + self._assert_called_before_not_after(ds, 4) + self._assert_called_before_not_after(ds, 6) + + with ds[5].result: + self._assert_called_before_not_after(ds, 6) + self._assert_called_before_not_after(ds, 6) + + with ds[4].result: + self._assert_called_before_not_after(ds, 6) + self._assert_called_before_not_after(ds, 7) + + with ds[6].result: + pass + + d = rwlock.write(key) + self.assertTrue(d.called) + with d.result: + pass + + d = rwlock.read(key) + self.assertTrue(d.called) + with d.result: + pass |