summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10017.misc1
-rw-r--r--synapse/util/batching_queue.py153
-rw-r--r--tests/util/test_batching_queue.py169
3 files changed, 323 insertions, 0 deletions
diff --git a/changelog.d/10017.misc b/changelog.d/10017.misc
new file mode 100644
index 0000000000..4777b7fb57
--- /dev/null
+++ b/changelog.d/10017.misc
@@ -0,0 +1 @@
+Add a batching queue implementation.
diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py
new file mode 100644
index 0000000000..44bbb7b1a8
--- /dev/null
+++ b/synapse/util/batching_queue.py
@@ -0,0 +1,153 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import (
+    Awaitable,
+    Callable,
+    Dict,
+    Generic,
+    Hashable,
+    List,
+    Set,
+    Tuple,
+    TypeVar,
+)
+
+from twisted.internet import defer
+
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
+
+logger = logging.getLogger(__name__)
+
+
+V = TypeVar("V")
+R = TypeVar("R")
+
+
+class BatchingQueue(Generic[V, R]):
+    """A queue that batches up work, calling the provided processing function
+    with all pending work (for a given key).
+
+    The provided processing function will only be called once at a time for each
+    key. It will be called the next reactor tick after `add_to_queue` has been
+    called, and will keep being called until the queue has been drained (for the
+    given key).
+
+    Note that the return value of `add_to_queue` will be the return value of the
+    processing function that processed the given item. This means that the
+    returned value will likely include data for other items that were in the
+    batch.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        clock: Clock,
+        process_batch_callback: Callable[[List[V]], Awaitable[R]],
+    ):
+        self._name = name
+        self._clock = clock
+
+        # The set of keys currently being processed.
+        self._processing_keys = set()  # type: Set[Hashable]
+
+        # The currently pending batch of values by key, with a Deferred to call
+        # with the result of the corresponding `_process_batch_callback` call.
+        self._next_values = {}  # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]]
+
+        # The function to call with batches of values.
+        self._process_batch_callback = process_batch_callback
+
+        LaterGauge(
+            "synapse_util_batching_queue_number_queued",
+            "The number of items waiting in the queue across all keys",
+            labels=("name",),
+            caller=lambda: sum(len(v) for v in self._next_values.values()),
+        )
+
+        LaterGauge(
+            "synapse_util_batching_queue_number_of_keys",
+            "The number of distinct keys that have items queued",
+            labels=("name",),
+            caller=lambda: len(self._next_values),
+        )
+
+    async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
+        """Adds the value to the queue with the given key, returning the result
+        of the processing function for the batch that included the given value.
+
+        The optional `key` argument allows sharding the queue by some key. The
+        queues will then be processed in parallel, i.e. the process batch
+        function will be called in parallel with batched values from a single
+        key.
+        """
+
+        # First we create a defer and add it and the value to the list of
+        # pending items.
+        d = defer.Deferred()
+        self._next_values.setdefault(key, []).append((value, d))
+
+        # If we're not currently processing the key fire off a background
+        # process to start processing.
+        if key not in self._processing_keys:
+            run_as_background_process(self._name, self._process_queue, key)
+
+        return await make_deferred_yieldable(d)
+
+    async def _process_queue(self, key: Hashable) -> None:
+        """A background task to repeatedly pull things off the queue for the
+        given key and call the `self._process_batch_callback` with the values.
+        """
+
+        try:
+            if key in self._processing_keys:
+                return
+
+            self._processing_keys.add(key)
+
+            while True:
+                # We purposefully wait a reactor tick to allow us to batch
+                # together requests that we're about to receive. A common
+                # pattern is to call `add_to_queue` multiple times at once, and
+                # deferring to the next reactor tick allows us to batch all of
+                # those up.
+                await self._clock.sleep(0)
+
+                next_values = self._next_values.pop(key, [])
+                if not next_values:
+                    # We've exhausted the queue.
+                    break
+
+                try:
+                    values = [value for value, _ in next_values]
+                    results = await self._process_batch_callback(values)
+
+                    for _, deferred in next_values:
+                        with PreserveLoggingContext():
+                            deferred.callback(results)
+
+                except Exception as e:
+                    for _, deferred in next_values:
+                        if deferred.called:
+                            continue
+
+                        with PreserveLoggingContext():
+                            deferred.errback(e)
+
+        finally:
+            self._processing_keys.discard(key)
diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
new file mode 100644
index 0000000000..5def1e56c9
--- /dev/null
+++ b/tests/util/test_batching_queue.py
@@ -0,0 +1,169 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util.batching_queue import BatchingQueue
+
+from tests.server import get_clock
+from tests.unittest import TestCase
+
+
+class BatchingQueueTestCase(TestCase):
+    def setUp(self):
+        self.clock, hs_clock = get_clock()
+
+        self._pending_calls = []
+        self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
+
+    async def _process_queue(self, values):
+        d = defer.Deferred()
+        self._pending_calls.append((values, d))
+        return await make_deferred_yieldable(d)
+
+    def test_simple(self):
+        """Tests the basic case of calling `add_to_queue` once and having
+        `_process_queue` return.
+        """
+
+        self.assertFalse(self._pending_calls)
+
+        queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))
+
+        # The queue should wait a reactor tick before calling the processing
+        # function.
+        self.assertFalse(self._pending_calls)
+        self.assertFalse(queue_d.called)
+
+        # We should see a call to `_process_queue` after a reactor tick.
+        self.clock.pump([0])
+
+        self.assertEqual(len(self._pending_calls), 1)
+        self.assertEqual(self._pending_calls[0][0], ["foo"])
+        self.assertFalse(queue_d.called)
+
+        # Return value of the `_process_queue` should be propagated back.
+        self._pending_calls.pop()[1].callback("bar")
+
+        self.assertEqual(self.successResultOf(queue_d), "bar")
+
+    def test_batching(self):
+        """Test that multiple calls at the same time get batched up into one
+        call to `_process_queue`.
+        """
+
+        self.assertFalse(self._pending_calls)
+
+        queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
+        queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
+
+        self.clock.pump([0])
+
+        # We should see only *one* call to `_process_queue`
+        self.assertEqual(len(self._pending_calls), 1)
+        self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
+        self.assertFalse(queue_d1.called)
+        self.assertFalse(queue_d2.called)
+
+        # Return value of the `_process_queue` should be propagated back to both.
+        self._pending_calls.pop()[1].callback("bar")
+
+        self.assertEqual(self.successResultOf(queue_d1), "bar")
+        self.assertEqual(self.successResultOf(queue_d2), "bar")
+
+    def test_queuing(self):
+        """Test that we queue up requests while a `_process_queue` is being
+        called.
+        """
+
+        self.assertFalse(self._pending_calls)
+
+        queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
+        self.clock.pump([0])
+
+        queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
+
+        # We should see only *one* call to `_process_queue`
+        self.assertEqual(len(self._pending_calls), 1)
+        self.assertEqual(self._pending_calls[0][0], ["foo1"])
+        self.assertFalse(queue_d1.called)
+        self.assertFalse(queue_d2.called)
+
+        # Return value of the `_process_queue` should be propagated back to the
+        # first.
+        self._pending_calls.pop()[1].callback("bar1")
+
+        self.assertEqual(self.successResultOf(queue_d1), "bar1")
+        self.assertFalse(queue_d2.called)
+
+        # We should now see a second call to `_process_queue`
+        self.clock.pump([0])
+        self.assertEqual(len(self._pending_calls), 1)
+        self.assertEqual(self._pending_calls[0][0], ["foo2"])
+        self.assertFalse(queue_d2.called)
+
+        # Return value of the `_process_queue` should be propagated back to the
+        # second.
+        self._pending_calls.pop()[1].callback("bar2")
+
+        self.assertEqual(self.successResultOf(queue_d2), "bar2")
+
+    def test_different_keys(self):
+        """Test that calls to different keys get processed in parallel."""
+
+        self.assertFalse(self._pending_calls)
+
+        queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1))
+        self.clock.pump([0])
+        queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2))
+        self.clock.pump([0])
+
+        # We queue up another item with key=2 to check that we will keep taking
+        # things off the queue.
+        queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3", key=2))
+
+        # We should see two calls to `_process_queue`
+        self.assertEqual(len(self._pending_calls), 2)
+        self.assertEqual(self._pending_calls[0][0], ["foo1"])
+        self.assertEqual(self._pending_calls[1][0], ["foo2"])
+        self.assertFalse(queue_d1.called)
+        self.assertFalse(queue_d2.called)
+        self.assertFalse(queue_d3.called)
+
+        # Return value of the `_process_queue` should be propagated back to the
+        # first.
+        self._pending_calls.pop(0)[1].callback("bar1")
+
+        self.assertEqual(self.successResultOf(queue_d1), "bar1")
+        self.assertFalse(queue_d2.called)
+        self.assertFalse(queue_d3.called)
+
+        # Return value of the `_process_queue` should be propagated back to the
+        # second.
+        self._pending_calls.pop()[1].callback("bar2")
+
+        self.assertEqual(self.successResultOf(queue_d2), "bar2")
+        self.assertFalse(queue_d3.called)
+
+        # We should now see a call `_pending_calls` for `foo3`
+        self.clock.pump([0])
+        self.assertEqual(len(self._pending_calls), 1)
+        self.assertEqual(self._pending_calls[0][0], ["foo3"])
+        self.assertFalse(queue_d3.called)
+
+        # Return value of the `_process_queue` should be propagated back to the
+        # third deferred.
+        self._pending_calls.pop()[1].callback("bar4")
+
+        self.assertEqual(self.successResultOf(queue_d3), "bar4")