summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2022-04-05 15:43:52 +0100
committerGitHub <noreply@github.com>2022-04-05 15:43:52 +0100
commit800ba87cc881856adae19ec40485578356398639 (patch)
treecb19370283b7def590239b507c5dc5c72e39af8a /synapse/util
parentMerge branch 'master' into develop (diff)
downloadsynapse-800ba87cc881856adae19ec40485578356398639.tar.xz
Refactor and convert `Linearizer` to async (#12357)
Refactor and convert `Linearizer` to async. This makes a `Linearizer`
cancellation bug easier to fix.

Also refactor to use an async context manager, which eliminates an
unlikely footgun where code that doesn't immediately use the context
manager could forget to release the lock.

Signed-off-by: Sean Quah <seanq@element.io>
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async_helpers.py144
1 files changed, 67 insertions, 77 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 6a8e844d63..4b2a16a6a9 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -18,7 +18,7 @@ import collections
 import inspect
 import itertools
 import logging
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import asynccontextmanager
 from typing import (
     Any,
     AsyncIterator,
@@ -29,7 +29,6 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
-    Iterator,
     List,
     Optional,
     Set,
@@ -342,7 +341,7 @@ class Linearizer:
 
     Example:
 
-        with await limiter.queue("test_key"):
+        async with limiter.queue("test_key"):
             # do some work.
 
     """
@@ -383,95 +382,53 @@ class Linearizer:
         # non-empty.
         return bool(entry.deferreds)
 
-    def queue(self, key: Hashable) -> defer.Deferred:
-        # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
-        # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
-        # propagated inside inlineCallbacks until Twisted 18.7)
+    def queue(self, key: Hashable) -> AsyncContextManager[None]:
+        @asynccontextmanager
+        async def _ctx_manager() -> AsyncIterator[None]:
+            entry = await self._acquire_lock(key)
+            try:
+                yield
+            finally:
+                self._release_lock(key, entry)
+
+        return _ctx_manager()
+
+    async def _acquire_lock(self, key: Hashable) -> _LinearizerEntry:
+        """Acquires a linearizer lock, waiting if necessary.
+
+        Returns once we have secured the lock.
+        """
         entry = self.key_to_defer.setdefault(
             key, _LinearizerEntry(0, collections.OrderedDict())
         )
 
-        # If the number of things executing is greater than the maximum
-        # then add a deferred to the list of blocked items
-        # When one of the things currently executing finishes it will callback
-        # this item so that it can continue executing.
-        if entry.count >= self.max_count:
-            res = self._await_lock(key)
-        else:
+        if entry.count < self.max_count:
+            # The number of things executing is less than the maximum.
             logger.debug(
                 "Acquired uncontended linearizer lock %r for key %r", self.name, key
             )
             entry.count += 1
-            res = defer.succeed(None)
-
-        # once we successfully get the lock, we need to return a context manager which
-        # will release the lock.
-
-        @contextmanager
-        def _ctx_manager(_: None) -> Iterator[None]:
-            try:
-                yield
-            finally:
-                logger.debug("Releasing linearizer lock %r for key %r", self.name, key)
-
-                # We've finished executing so check if there are any things
-                # blocked waiting to execute and start one of them
-                entry.count -= 1
-
-                if entry.deferreds:
-                    (next_def, _) = entry.deferreds.popitem(last=False)
-
-                    # we need to run the next thing in the sentinel context.
-                    with PreserveLoggingContext():
-                        next_def.callback(None)
-                elif entry.count == 0:
-                    # We were the last thing for this key: remove it from the
-                    # map.
-                    del self.key_to_defer[key]
-
-        res.addCallback(_ctx_manager)
-        return res
-
-    def _await_lock(self, key: Hashable) -> defer.Deferred:
-        """Helper for queue: adds a deferred to the queue
-
-        Assumes that we've already checked that we've reached the limit of the number
-        of lock-holders we allow. Creates a new deferred which is added to the list, and
-        adds some management around cancellations.
-
-        Returns the deferred, which will callback once we have secured the lock.
-
-        """
-        entry = self.key_to_defer[key]
+            return entry
 
+        # Otherwise, the number of things executing is at the maximum and we have to
+        # add a deferred to the list of blocked items.
+        # When one of the things currently executing finishes it will callback
+        # this item so that it can continue executing.
         logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
 
         new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
         entry.deferreds[new_defer] = 1
 
-        def cb(_r: None) -> "defer.Deferred[None]":
-            logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
-            entry.count += 1
-
-            # if the code holding the lock completes synchronously, then it
-            # will recursively run the next claimant on the list. That can
-            # relatively rapidly lead to stack exhaustion. This is essentially
-            # the same problem as http://twistedmatrix.com/trac/ticket/9304.
-            #
-            # In order to break the cycle, we add a cheeky sleep(0) here to
-            # ensure that we fall back to the reactor between each iteration.
-            #
-            # (This needs to happen while we hold the lock, and the context manager's exit
-            # code must be synchronous, so this is the only sensible place.)
-            return self._clock.sleep(0)
-
-        def eb(e: Failure) -> Failure:
+        try:
+            await new_defer
+        except Exception as e:
             logger.info("defer %r got err %r", new_defer, e)
             if isinstance(e, CancelledError):
                 logger.debug(
-                    "Cancelling wait for linearizer lock %r for key %r", self.name, key
+                    "Cancelling wait for linearizer lock %r for key %r",
+                    self.name,
+                    key,
                 )
-
             else:
                 logger.warning(
                     "Unexpected exception waiting for linearizer lock %r for key %r",
@@ -481,10 +438,43 @@ class Linearizer:
 
             # we just have to take ourselves back out of the queue.
             del entry.deferreds[new_defer]
-            return e
+            raise
+
+        logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
+        entry.count += 1
 
-        new_defer.addCallbacks(cb, eb)
-        return new_defer
+        # if the code holding the lock completes synchronously, then it
+        # will recursively run the next claimant on the list. That can
+        # relatively rapidly lead to stack exhaustion. This is essentially
+        # the same problem as http://twistedmatrix.com/trac/ticket/9304.
+        #
+        # In order to break the cycle, we add a cheeky sleep(0) here to
+        # ensure that we fall back to the reactor between each iteration.
+        #
+        # This needs to happen while we hold the lock. We could put it on the
+        # exit path, but that would slow down the uncontended case.
+        await self._clock.sleep(0)
+
+        return entry
+
+    def _release_lock(self, key: Hashable, entry: _LinearizerEntry) -> None:
+        """Releases a held linearizer lock."""
+        logger.debug("Releasing linearizer lock %r for key %r", self.name, key)
+
+        # We've finished executing so check if there are any things
+        # blocked waiting to execute and start one of them
+        entry.count -= 1
+
+        if entry.deferreds:
+            (next_def, _) = entry.deferreds.popitem(last=False)
+
+            # we need to run the next thing in the sentinel context.
+            with PreserveLoggingContext():
+                next_def.callback(None)
+        elif entry.count == 0:
+            # We were the last thing for this key: remove it from the
+            # map.
+            del self.key_to_defer[key]
 
 
 class ReadWriteLock: