diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/handlers/pagination.py | 8 | ||||
-rw-r--r-- | synapse/util/async_helpers.py | 71 |
2 files changed, 44 insertions, 35 deletions
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 183fabcfc0..60059fec3e 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -350,7 +350,7 @@ class PaginationHandler: """ self._purges_in_progress_by_room.add(room_id) try: - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): await self.storage.purge_events.purge_history( room_id, token, delete_local_events ) @@ -406,7 +406,7 @@ class PaginationHandler: room_id: room to be purged force: set true to skip checking for joined users. """ - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) @@ -448,7 +448,7 @@ class PaginationHandler: room_token = from_token.room_key - with await self.pagination_lock.read(room_id): + async with self.pagination_lock.read(room_id): ( membership, member_event_id, @@ -615,7 +615,7 @@ class PaginationHandler: self._purges_in_progress_by_room.add(room_id) try: - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN self._delete_by_id[ delete_id diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 69c8c1baa9..6a8e844d63 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -18,9 +18,10 @@ import collections import inspect import itertools import logging -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from typing import ( Any, + AsyncIterator, Awaitable, Callable, Collection, @@ -40,7 +41,7 @@ from typing import ( ) import attr -from typing_extensions import ContextManager, Literal +from typing_extensions import AsyncContextManager, Literal from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -491,7 +492,7 @@ class ReadWriteLock: Example: - with await read_write_lock.read("test_key"): + async with read_write_lock.read("test_key"): # do some work """ @@ -514,22 +515,24 @@ class ReadWriteLock: # Latest writer queued self.key_to_current_writer: Dict[str, defer.Deferred] = {} - async def read(self, key: str) -> ContextManager: - new_defer: "defer.Deferred[None]" = defer.Deferred() + def read(self, key: str) -> AsyncContextManager: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: + new_defer: "defer.Deferred[None]" = defer.Deferred() - curr_readers = self.key_to_current_readers.setdefault(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) - curr_readers.add(new_defer) + curr_readers.add(new_defer) - # We wait for the latest writer to finish writing. We can safely ignore - # any existing readers... as they're readers. - if curr_writer: - await make_deferred_yieldable(curr_writer) - - @contextmanager - def _ctx_manager() -> Iterator[None]: try: + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + # May raise a `CancelledError` if the `Deferred` wrapping us is + # cancelled. The `Deferred` we are waiting on must not be cancelled, + # since we do not own it. + if curr_writer: + await make_deferred_yieldable(stop_cancellation(curr_writer)) yield finally: with PreserveLoggingContext(): @@ -538,29 +541,35 @@ class ReadWriteLock: return _ctx_manager() - async def write(self, key: str) -> ContextManager: - new_defer: "defer.Deferred[None]" = defer.Deferred() + def write(self, key: str) -> AsyncContextManager: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: + new_defer: "defer.Deferred[None]" = defer.Deferred() - curr_readers = self.key_to_current_readers.get(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) - # We wait on all latest readers and writer. - to_wait_on = list(curr_readers) - if curr_writer: - to_wait_on.append(curr_writer) + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) - # We can clear the list of current readers since the new writer waits - # for them to finish. - curr_readers.clear() - self.key_to_current_writer[key] = new_defer + # We can clear the list of current readers since `new_defer` waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer - await make_deferred_yieldable(defer.gatherResults(to_wait_on)) - - @contextmanager - def _ctx_manager() -> Iterator[None]: + to_wait_on_defer = defer.gatherResults(to_wait_on) try: + # Wait for all current readers and the latest writer to finish. + # May raise a `CancelledError` immediately after the wait if the + # `Deferred` wrapping us is cancelled. We must only release the lock + # once we have acquired it, hence the use of `delay_cancellation` + # rather than `stop_cancellation`. + await make_deferred_yieldable(delay_cancellation(to_wait_on_defer)) yield finally: + # Release the lock. with PreserveLoggingContext(): new_defer.callback(None) # `self.key_to_current_writer[key]` may be missing if there was another |