diff options
author | Sean Quah <seanq@element.io> | 2022-02-24 13:56:38 +0000 |
---|---|---|
committer | Sean Quah <seanq@element.io> | 2022-03-08 17:11:51 +0000 |
commit | 7b19bc68ce008d96d6bba242acf88239a40e0404 (patch) | |
tree | 88603187c6c064eee39233146714f7c0621bfe3c | |
parent | Use `ParamSpec` in type hints for `synapse.logging.context` (#12150) (diff) | |
download | synapse-7b19bc68ce008d96d6bba242acf88239a40e0404.tar.xz |
Convert `ReadWriteLock` to use async context managers
Has the side effect of fixing clean up for readers cancelled while waiting. Breaks the assumption that resolution of a writer `Deferred` means that previous readers and writers have completed, which will be fixed in the next commit. Signed-off-by: Sean Quah <seanq@element.io>
-rw-r--r-- | synapse/handlers/pagination.py | 8 | ||||
-rw-r--r-- | synapse/util/async_helpers.py | 31 | ||||
-rw-r--r-- | tests/util/test_rwlock.py | 95 |
3 files changed, 74 insertions, 60 deletions
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 183fabcfc0..60059fec3e 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -350,7 +350,7 @@ class PaginationHandler: """ self._purges_in_progress_by_room.add(room_id) try: - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): await self.storage.purge_events.purge_history( room_id, token, delete_local_events ) @@ -406,7 +406,7 @@ class PaginationHandler: room_id: room to be purged force: set true to skip checking for joined users. """ - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) @@ -448,7 +448,7 @@ class PaginationHandler: room_token = from_token.room_key - with await self.pagination_lock.read(room_id): + async with self.pagination_lock.read(room_id): ( membership, member_event_id, @@ -615,7 +615,7 @@ class PaginationHandler: self._purges_in_progress_by_room.add(room_id) try: - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN self._delete_by_id[ delete_id diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a9f67dcbac..12a572cdd6 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -18,9 +18,10 @@ import collections import inspect import itertools import logging -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from typing import ( Any, + AsyncIterator, Awaitable, Callable, Collection, @@ -40,7 +41,7 @@ from typing import ( ) import attr -from typing_extensions import ContextManager, Literal +from typing_extensions import AsyncContextManager, Literal from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -491,7 +492,7 @@ class ReadWriteLock: Example: - with await read_write_lock.read("test_key"): + async with read_write_lock.read("test_key"): # do some work """ @@ -514,7 +515,7 @@ class ReadWriteLock: # Latest writer queued self.key_to_current_writer: Dict[str, defer.Deferred] = {} - async def read(self, key: str) -> ContextManager: + def read(self, key: str) -> AsyncContextManager: new_defer: "defer.Deferred[None]" = defer.Deferred() curr_readers = self.key_to_current_readers.setdefault(key, set()) @@ -522,14 +523,13 @@ class ReadWriteLock: curr_readers.add(new_defer) - # We wait for the latest writer to finish writing. We can safely ignore - # any existing readers... as they're readers. - if curr_writer: - await make_deferred_yieldable(curr_writer) - - @contextmanager - def _ctx_manager() -> Iterator[None]: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: try: + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + if curr_writer: + await make_deferred_yieldable(curr_writer) yield finally: with PreserveLoggingContext(): @@ -538,7 +538,7 @@ class ReadWriteLock: return _ctx_manager() - async def write(self, key: str) -> ContextManager: + def write(self, key: str) -> AsyncContextManager: new_defer: "defer.Deferred[None]" = defer.Deferred() curr_readers = self.key_to_current_readers.get(key, set()) @@ -554,11 +554,10 @@ class ReadWriteLock: curr_readers.clear() self.key_to_current_writer[key] = new_defer - await make_deferred_yieldable(defer.gatherResults(to_wait_on)) - - @contextmanager - def _ctx_manager() -> Iterator[None]: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: try: + await make_deferred_yieldable(defer.gatherResults(to_wait_on)) yield finally: with PreserveLoggingContext(): diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 0774625b85..018cd50f51 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import AsyncContextManager, Callable, Tuple + from twisted.internet import defer from twisted.internet.defer import Deferred @@ -32,58 +34,71 @@ class ReadWriteLockTestCase(unittest.TestCase): def test_rwlock(self): rwlock = ReadWriteLock() + key = "key" + + def start_reader_or_writer( + read_or_write: Callable[[str], AsyncContextManager] + ) -> Tuple["Deferred[None]", "Deferred[None]"]: + acquired_d: "Deferred[None]" = Deferred() + release_d: "Deferred[None]" = Deferred() + + async def action(): + async with read_or_write(key): + acquired_d.callback(None) + await release_d - key = object() + defer.ensureDeferred(action()) + return acquired_d, release_d ds = [ - rwlock.read(key), # 0 - rwlock.read(key), # 1 - rwlock.write(key), # 2 - rwlock.write(key), # 3 - rwlock.read(key), # 4 - rwlock.read(key), # 5 - rwlock.write(key), # 6 + start_reader_or_writer(rwlock.read), # 0 + start_reader_or_writer(rwlock.read), # 1 + start_reader_or_writer(rwlock.write), # 2 + start_reader_or_writer(rwlock.write), # 3 + start_reader_or_writer(rwlock.read), # 4 + start_reader_or_writer(rwlock.read), # 5 + start_reader_or_writer(rwlock.write), # 6 ] - ds = [defer.ensureDeferred(d) for d in ds] + # `Deferred`s that resolve when each reader or writer acquires the lock. + acquired_ds = [acquired_d for acquired_d, _release_d in ds] + # `Deferred`s that will trigger the release of locks when resolved. + release_ds = [release_d for _acquired_d, release_d in ds] - self._assert_called_before_not_after(ds, 2) + self._assert_called_before_not_after(acquired_ds, 2) - with ds[0].result: - self._assert_called_before_not_after(ds, 2) - self._assert_called_before_not_after(ds, 2) + self._assert_called_before_not_after(acquired_ds, 2) + release_ds[0].callback(None) + self._assert_called_before_not_after(acquired_ds, 2) - with ds[1].result: - self._assert_called_before_not_after(ds, 2) - self._assert_called_before_not_after(ds, 3) + self._assert_called_before_not_after(acquired_ds, 2) + release_ds[1].callback(None) + self._assert_called_before_not_after(acquired_ds, 3) - with ds[2].result: - self._assert_called_before_not_after(ds, 3) - self._assert_called_before_not_after(ds, 4) + self._assert_called_before_not_after(acquired_ds, 3) + release_ds[2].callback(None) + self._assert_called_before_not_after(acquired_ds, 4) - with ds[3].result: - self._assert_called_before_not_after(ds, 4) - self._assert_called_before_not_after(ds, 6) + self._assert_called_before_not_after(acquired_ds, 4) + release_ds[3].callback(None) + self._assert_called_before_not_after(acquired_ds, 6) - with ds[5].result: - self._assert_called_before_not_after(ds, 6) - self._assert_called_before_not_after(ds, 6) + self._assert_called_before_not_after(acquired_ds, 6) + release_ds[5].callback(None) + self._assert_called_before_not_after(acquired_ds, 6) - with ds[4].result: - self._assert_called_before_not_after(ds, 6) - self._assert_called_before_not_after(ds, 7) + self._assert_called_before_not_after(acquired_ds, 6) + release_ds[4].callback(None) + self._assert_called_before_not_after(acquired_ds, 7) - with ds[6].result: - pass + release_ds[6].callback(None) - d = defer.ensureDeferred(rwlock.write(key)) - self.assertTrue(d.called) - with d.result: - pass + acquired_d, release_d = start_reader_or_writer(rwlock.write) + self.assertTrue(acquired_d.called) + release_d.callback(None) - d = defer.ensureDeferred(rwlock.read(key)) - self.assertTrue(d.called) - with d.result: - pass + acquired_d, release_d = start_reader_or_writer(rwlock.read) + self.assertTrue(acquired_d.called) + release_d.callback(None) def test_lock_handoff_to_nonblocking_writer(self): """Test a writer handing the lock to another writer that completes instantly.""" @@ -93,11 +108,11 @@ class ReadWriteLockTestCase(unittest.TestCase): unblock: "Deferred[None]" = Deferred() async def blocking_write(): - with await rwlock.write(key): + async with rwlock.write(key): await unblock async def nonblocking_write(): - with await rwlock.write(key): + async with rwlock.write(key): pass d1 = defer.ensureDeferred(blocking_write()) |