diff --git a/changelog.d/17884.misc b/changelog.d/17884.misc
new file mode 100644
index 0000000000..9dfa13f853
--- /dev/null
+++ b/changelog.d/17884.misc
@@ -0,0 +1 @@
+Minor speed-up of sliding sync by computing extensions results in parallel.
diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py
index 0c77b52513..077887ec32 100644
--- a/synapse/handlers/sliding_sync/extensions.py
+++ b/synapse/handlers/sliding_sync/extensions.py
@@ -49,7 +49,10 @@ from synapse.types.handlers.sliding_sync import (
SlidingSyncConfig,
SlidingSyncResult,
)
-from synapse.util.async_helpers import concurrently_execute
+from synapse.util.async_helpers import (
+ concurrently_execute,
+ gather_optional_coroutines,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -97,26 +100,26 @@ class SlidingSyncExtensionHandler:
if sync_config.extensions is None:
return SlidingSyncResult.Extensions()
- to_device_response = None
+ to_device_coro = None
if sync_config.extensions.to_device is not None:
- to_device_response = await self.get_to_device_extension_response(
+ to_device_coro = self.get_to_device_extension_response(
sync_config=sync_config,
to_device_request=sync_config.extensions.to_device,
to_token=to_token,
)
- e2ee_response = None
+ e2ee_coro = None
if sync_config.extensions.e2ee is not None:
- e2ee_response = await self.get_e2ee_extension_response(
+ e2ee_coro = self.get_e2ee_extension_response(
sync_config=sync_config,
e2ee_request=sync_config.extensions.e2ee,
to_token=to_token,
from_token=from_token,
)
- account_data_response = None
+ account_data_coro = None
if sync_config.extensions.account_data is not None:
- account_data_response = await self.get_account_data_extension_response(
+ account_data_coro = self.get_account_data_extension_response(
sync_config=sync_config,
previous_connection_state=previous_connection_state,
new_connection_state=new_connection_state,
@@ -127,9 +130,9 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
- receipts_response = None
+ receipts_coro = None
if sync_config.extensions.receipts is not None:
- receipts_response = await self.get_receipts_extension_response(
+ receipts_coro = self.get_receipts_extension_response(
sync_config=sync_config,
previous_connection_state=previous_connection_state,
new_connection_state=new_connection_state,
@@ -141,9 +144,9 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
- typing_response = None
+ typing_coro = None
if sync_config.extensions.typing is not None:
- typing_response = await self.get_typing_extension_response(
+ typing_coro = self.get_typing_extension_response(
sync_config=sync_config,
actual_lists=actual_lists,
actual_room_ids=actual_room_ids,
@@ -153,6 +156,20 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
+ (
+ to_device_response,
+ e2ee_response,
+ account_data_response,
+ receipts_response,
+ typing_response,
+ ) = await gather_optional_coroutines(
+ to_device_coro,
+ e2ee_coro,
+ account_data_coro,
+ receipts_coro,
+ typing_coro,
+ )
+
return SlidingSyncResult.Extensions(
to_device=to_device_response,
e2ee=e2ee_response,
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index ae2b3d11c0..8a2dfeba13 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -37,6 +37,7 @@ import warnings
from types import TracebackType
from typing import (
TYPE_CHECKING,
+ Any,
Awaitable,
Callable,
Optional,
@@ -850,6 +851,45 @@ def run_in_background(
return d
+def run_coroutine_in_background(
+ coroutine: typing.Coroutine[Any, Any, R],
+) -> "defer.Deferred[R]":
+ """Run the coroutine, ensuring that the current context is restored after
+ return from the function, and that the sentinel context is set once the
+ deferred returned by the function completes.
+
+ Useful for wrapping coroutines that you don't yield or await on (for
+ instance because you want to pass it to deferred.gatherResults()).
+
+ This is a special case of `run_in_background` where we can accept a
+ coroutine directly rather than a function. We can do this because coroutines
+ do not run until called, and so calling an async function without awaiting
+ cannot change the log contexts.
+ """
+
+ current = current_context()
+ d = defer.ensureDeferred(coroutine)
+
+ # The function may have reset the context before returning, so
+ # we need to restore it now.
+ ctx = set_current_context(current)
+
+ # The original context will be restored when the deferred
+ # completes, but there is nothing waiting for it, so it will
+ # get leaked into the reactor or some other function which
+ # wasn't expecting it. We therefore need to reset the context
+ # here.
+ #
+ # (If this feels asymmetric, consider it this way: we are
+ # effectively forking a new thread of execution. We are
+ # probably currently within a ``with LoggingContext()`` block,
+ # which is supposed to have a single entry and exit point. But
+ # by spawning off another deferred, we are effectively
+ # adding a new exit point.)
+ d.addBoth(_set_context_cb, ctx)
+ return d
+
+
T = TypeVar("T")
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 8618bb0651..e1eb8a4863 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -51,7 +51,7 @@ from typing import (
)
import attr
-from typing_extensions import Concatenate, Literal, ParamSpec
+from typing_extensions import Concatenate, Literal, ParamSpec, Unpack
from twisted.internet import defer
from twisted.internet.defer import CancelledError
@@ -61,6 +61,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
+ run_coroutine_in_background,
run_in_background,
)
from synapse.util import Clock
@@ -344,6 +345,7 @@ T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
+T5 = TypeVar("T5")
@overload
@@ -402,6 +404,112 @@ def gather_results( # type: ignore[misc]
return deferred.addCallback(tuple)
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]]]],
+) -> Tuple[Optional[T1]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ Optional[Coroutine[Any, Any, T3]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ Optional[Coroutine[Any, Any, T3]],
+ Optional[Coroutine[Any, Any, T4]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ Optional[Coroutine[Any, Any, T3]],
+ Optional[Coroutine[Any, Any, T4]],
+ Optional[Coroutine[Any, Any, T5]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ...
+
+
+async def gather_optional_coroutines(
+ *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
+) -> Tuple[Optional[T1], ...]:
+ """Helper function that allows waiting on multiple coroutines at once.
+
+ The return value is a tuple of the return values of the coroutines in order.
+
+ If a `None` is passed instead of a coroutine, it will be ignored and a None
+ is returned in the tuple.
+
+ Note: For typechecking we need to have an explicit overload for each
+ distinct number of coroutines passed in. If you see type problems, it's
+ likely because you're using many arguments and you need to add a new
+ overload above.
+ """
+
+ try:
+ results = await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_coroutine_in_background(coroutine)
+ for coroutine in coroutines
+ if coroutine is not None
+ ],
+ consumeErrors=True,
+ )
+ )
+
+ results_iter = iter(results)
+ return tuple(
+ next(results_iter) if coroutine is not None else None
+ for coroutine in coroutines
+ )
+ except defer.FirstError as dfe:
+ # unwrap the error from defer.gatherResults.
+
+ # The raised exception's traceback only includes func() etc if
+ # the 'await' happens before the exception is thrown - ie if the failure
+ # happens *asynchronously* - otherwise Twisted throws away the traceback as it
+ # could be large.
+ #
+ # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
+ # we could throw Twisted into the fires of Mordor.
+
+ # suppress exception chaining, because the FirstError doesn't tell us anything
+ # very interesting.
+ assert isinstance(dfe.subFailure.value, BaseException)
+ raise dfe.subFailure.value from None
+
+
@attr.s(slots=True, auto_attribs=True)
class _LinearizerEntry:
# The number of things executing.
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index d82822d00d..350a2b7c8c 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -18,7 +18,7 @@
#
#
import traceback
-from typing import Generator, List, NoReturn, Optional
+from typing import Any, Coroutine, Generator, List, NoReturn, Optional, Tuple, TypeVar
from parameterized import parameterized_class
@@ -39,6 +39,7 @@ from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
delay_cancellation,
+ gather_optional_coroutines,
stop_cancellation,
timeout_deferred,
)
@@ -46,6 +47,8 @@ from synapse.util.async_helpers import (
from tests.server import get_clock
from tests.unittest import TestCase
+T = TypeVar("T")
+
class ObservableDeferredTest(TestCase):
def test_succeed(self) -> None:
@@ -588,3 +591,106 @@ class AwakenableSleeperTests(TestCase):
sleeper.wake("name")
self.assertTrue(d1.called)
self.assertTrue(d2.called)
+
+
+class GatherCoroutineTests(TestCase):
+ """Tests for `gather_optional_coroutines`"""
+
+ def make_coroutine(self) -> Tuple[Coroutine[Any, Any, T], "defer.Deferred[T]"]:
+ """Returns a coroutine and a deferred that it is waiting on to resolve"""
+
+ d: "defer.Deferred[T]" = defer.Deferred()
+
+ async def inner() -> T:
+ with PreserveLoggingContext():
+ return await d
+
+ return inner(), d
+
+ def test_single(self) -> None:
+ "Test passing in a single coroutine works"
+
+ with LoggingContext("test_ctx") as text_ctx:
+ deferred: "defer.Deferred[None]"
+ coroutine, deferred = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Resolving the deferred will resolve the coroutine
+ deferred.callback(None)
+
+ # All coroutines have resolved, and so we should have the results
+ result = self.successResultOf(gather_deferred)
+ self.assertEqual(result, (None,))
+
+ # We should be back in the normal context.
+ self.assertEqual(current_context(), text_ctx)
+
+ def test_multiple_resolve(self) -> None:
+ "Test passing in multiple coroutine that all resolve works"
+
+ with LoggingContext("test_ctx") as test_ctx:
+ deferred1: "defer.Deferred[int]"
+ coroutine1, deferred1 = self.make_coroutine()
+ deferred2: "defer.Deferred[str]"
+ coroutine2, deferred2 = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine1, coroutine2)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Even if we resolve one of the coroutines, we shouldn't have a result
+ # yet
+ deferred2.callback("test")
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ deferred1.callback(1)
+
+ # All coroutines have resolved, and so we should have the results
+ result = self.successResultOf(gather_deferred)
+ self.assertEqual(result, (1, "test"))
+
+ # We should be back in the normal context.
+ self.assertEqual(current_context(), test_ctx)
+
+ def test_multiple_fail(self) -> None:
+ "Test passing in multiple coroutine where one fails does the right thing"
+
+ with LoggingContext("test_ctx") as test_ctx:
+ deferred1: "defer.Deferred[int]"
+ coroutine1, deferred1 = self.make_coroutine()
+ deferred2: "defer.Deferred[str]"
+ coroutine2, deferred2 = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine1, coroutine2)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Throw an exception in one of the coroutines
+ exc = Exception("test")
+ deferred2.errback(exc)
+
+ # Expect the gather deferred to immediately fail
+ result_exc = self.failureResultOf(gather_deferred)
+ self.assertEqual(result_exc.value, exc)
+
+ # We should be back in the normal context.
+ self.assertEqual(current_context(), test_ctx)
|