summary refs log tree commit diff
path: root/tests/util/test_batching_queue.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/test_batching_queue.py')
-rw-r--r--tests/util/test_batching_queue.py30
1 files changed, 18 insertions, 12 deletions
diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
index 07be57d72c..94ef91f645 100644
--- a/tests/util/test_batching_queue.py
+++ b/tests/util/test_batching_queue.py
@@ -11,6 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List, Tuple
+
+from prometheus_client import Gauge
+
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable
@@ -26,7 +30,7 @@ from tests.unittest import TestCase
 
 
 class BatchingQueueTestCase(TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.clock, hs_clock = get_clock()
 
         # We ensure that we remove any existing metrics for "test_queue".
@@ -37,25 +41,27 @@ class BatchingQueueTestCase(TestCase):
         except KeyError:
             pass
 
-        self._pending_calls = []
-        self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
+        self._pending_calls: List[Tuple[List[str], defer.Deferred]] = []
+        self.queue: BatchingQueue[str, str] = BatchingQueue(
+            "test_queue", hs_clock, self._process_queue
+        )
 
-    async def _process_queue(self, values):
-        d = defer.Deferred()
+    async def _process_queue(self, values: List[str]) -> str:
+        d: "defer.Deferred[str]" = defer.Deferred()
         self._pending_calls.append((values, d))
         return await make_deferred_yieldable(d)
 
-    def _get_sample_with_name(self, metric, name) -> int:
+    def _get_sample_with_name(self, metric: Gauge, name: str) -> float:
         """For a prometheus metric get the value of the sample that has a
         matching "name" label.
         """
-        for sample in metric.collect()[0].samples:
+        for sample in next(iter(metric.collect())).samples:
             if sample.labels.get("name") == name:
                 return sample.value
 
         self.fail("Found no matching sample")
 
-    def _assert_metrics(self, queued, keys, in_flight):
+    def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None:
         """Assert that the metrics are correct"""
 
         sample = self._get_sample_with_name(number_queued, self.queue._name)
@@ -75,7 +81,7 @@ class BatchingQueueTestCase(TestCase):
             "number_in_flight",
         )
 
-    def test_simple(self):
+    def test_simple(self) -> None:
         """Tests the basic case of calling `add_to_queue` once and having
         `_process_queue` return.
         """
@@ -106,7 +112,7 @@ class BatchingQueueTestCase(TestCase):
 
         self._assert_metrics(queued=0, keys=0, in_flight=0)
 
-    def test_batching(self):
+    def test_batching(self) -> None:
         """Test that multiple calls at the same time get batched up into one
         call to `_process_queue`.
         """
@@ -134,7 +140,7 @@ class BatchingQueueTestCase(TestCase):
         self.assertEqual(self.successResultOf(queue_d2), "bar")
         self._assert_metrics(queued=0, keys=0, in_flight=0)
 
-    def test_queuing(self):
+    def test_queuing(self) -> None:
         """Test that we queue up requests while a `_process_queue` is being
         called.
         """
@@ -184,7 +190,7 @@ class BatchingQueueTestCase(TestCase):
         self.assertEqual(self.successResultOf(queue_d3), "bar2")
         self._assert_metrics(queued=0, keys=0, in_flight=0)
 
-    def test_different_keys(self):
+    def test_different_keys(self) -> None:
         """Test that calls to different keys get processed in parallel."""
 
         self.assertFalse(self._pending_calls)