summary refs log tree commit diff
path: root/synapse/util/async_helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/async_helpers.py')
-rw-r--r--synapse/util/async_helpers.py71
1 files changed, 40 insertions, 31 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 69c8c1baa9..6a8e844d63 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -18,9 +18,10 @@ import collections
 import inspect
 import itertools
 import logging
-from contextlib import contextmanager
+from contextlib import asynccontextmanager, contextmanager
 from typing import (
     Any,
+    AsyncIterator,
     Awaitable,
     Callable,
     Collection,
@@ -40,7 +41,7 @@ from typing import (
 )
 
 import attr
-from typing_extensions import ContextManager, Literal
+from typing_extensions import AsyncContextManager, Literal
 
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
@@ -491,7 +492,7 @@ class ReadWriteLock:
 
     Example:
 
-        with await read_write_lock.read("test_key"):
+        async with read_write_lock.read("test_key"):
             # do some work
     """
 
@@ -514,22 +515,24 @@ class ReadWriteLock:
         # Latest writer queued
         self.key_to_current_writer: Dict[str, defer.Deferred] = {}
 
-    async def read(self, key: str) -> ContextManager:
-        new_defer: "defer.Deferred[None]" = defer.Deferred()
+    def read(self, key: str) -> AsyncContextManager:
+        @asynccontextmanager
+        async def _ctx_manager() -> AsyncIterator[None]:
+            new_defer: "defer.Deferred[None]" = defer.Deferred()
 
-        curr_readers = self.key_to_current_readers.setdefault(key, set())
-        curr_writer = self.key_to_current_writer.get(key, None)
+            curr_readers = self.key_to_current_readers.setdefault(key, set())
+            curr_writer = self.key_to_current_writer.get(key, None)
 
-        curr_readers.add(new_defer)
+            curr_readers.add(new_defer)
 
-        # We wait for the latest writer to finish writing. We can safely ignore
-        # any existing readers... as they're readers.
-        if curr_writer:
-            await make_deferred_yieldable(curr_writer)
-
-        @contextmanager
-        def _ctx_manager() -> Iterator[None]:
             try:
+                # We wait for the latest writer to finish writing. We can safely ignore
+                # any existing readers... as they're readers.
+                # May raise a `CancelledError` if the `Deferred` wrapping us is
+                # cancelled. The `Deferred` we are waiting on must not be cancelled,
+                # since we do not own it.
+                if curr_writer:
+                    await make_deferred_yieldable(stop_cancellation(curr_writer))
                 yield
             finally:
                 with PreserveLoggingContext():
@@ -538,29 +541,35 @@ class ReadWriteLock:
 
         return _ctx_manager()
 
-    async def write(self, key: str) -> ContextManager:
-        new_defer: "defer.Deferred[None]" = defer.Deferred()
+    def write(self, key: str) -> AsyncContextManager:
+        @asynccontextmanager
+        async def _ctx_manager() -> AsyncIterator[None]:
+            new_defer: "defer.Deferred[None]" = defer.Deferred()
 
-        curr_readers = self.key_to_current_readers.get(key, set())
-        curr_writer = self.key_to_current_writer.get(key, None)
+            curr_readers = self.key_to_current_readers.get(key, set())
+            curr_writer = self.key_to_current_writer.get(key, None)
 
-        # We wait on all latest readers and writer.
-        to_wait_on = list(curr_readers)
-        if curr_writer:
-            to_wait_on.append(curr_writer)
+            # We wait on all latest readers and writer.
+            to_wait_on = list(curr_readers)
+            if curr_writer:
+                to_wait_on.append(curr_writer)
 
-        # We can clear the list of current readers since the new writer waits
-        # for them to finish.
-        curr_readers.clear()
-        self.key_to_current_writer[key] = new_defer
+            # We can clear the list of current readers since `new_defer` waits
+            # for them to finish.
+            curr_readers.clear()
+            self.key_to_current_writer[key] = new_defer
 
-        await make_deferred_yieldable(defer.gatherResults(to_wait_on))
-
-        @contextmanager
-        def _ctx_manager() -> Iterator[None]:
+            to_wait_on_defer = defer.gatherResults(to_wait_on)
             try:
+                # Wait for all current readers and the latest writer to finish.
+                # May raise a `CancelledError` immediately after the wait if the
+                # `Deferred` wrapping us is cancelled. We must only release the lock
+                # once we have acquired it, hence the use of `delay_cancellation`
+                # rather than `stop_cancellation`.
+                await make_deferred_yieldable(delay_cancellation(to_wait_on_defer))
                 yield
             finally:
+                # Release the lock.
                 with PreserveLoggingContext():
                     new_defer.callback(None)
                 # `self.key_to_current_writer[key]` may be missing if there was another