summary refs log tree commit diff
path: root/tests/util/test_file_consumer.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-12-02 12:58:56 -0500
committerGitHub <noreply@github.com>2022-12-02 17:58:56 +0000
commitacea4d7a2ff61b5beda420b54a8451088060a8cd (patch)
treee5677ed7cec1ad0dcdbe88566d697ae7bf8b2df2 /tests/util/test_file_consumer.py
parentProperly handle unknown results for the stream change cache. (#14592) (diff)
downloadsynapse-acea4d7a2ff61b5beda420b54a8451088060a8cd.tar.xz
Add missing types to tests.util. (#14597)
Removes files under tests.util from the ignored by list, then
fully types all tests/util/*.py files.
Diffstat (limited to 'tests/util/test_file_consumer.py')
-rw-r--r--tests/util/test_file_consumer.py103
1 files changed, 60 insertions, 43 deletions
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index 3bb4695405..4f3c983c15 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -12,22 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import threading
-from io import StringIO
+from io import BytesIO
+from typing import BinaryIO, Generator, Optional, cast
 from unittest.mock import NonCallableMock
 
-from twisted.internet import defer, reactor
+from zope.interface import implementer
+
+from twisted.internet import defer, reactor as _reactor
+from twisted.internet.interfaces import IPullProducer
 
+from synapse.types import ISynapseReactor
 from synapse.util.file_consumer import BackgroundFileConsumer
 
 from tests import unittest
 
+reactor = cast(ISynapseReactor, _reactor)
+
 
 class FileConsumerTests(unittest.TestCase):
     @defer.inlineCallbacks
-    def test_pull_consumer(self):
-        string_file = StringIO()
+    def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
+        string_file = BytesIO()
         consumer = BackgroundFileConsumer(string_file, reactor=reactor)
 
         try:
@@ -35,55 +41,57 @@ class FileConsumerTests(unittest.TestCase):
 
             yield producer.register_with_consumer(consumer)
 
-            yield producer.write_and_wait("Foo")
+            yield producer.write_and_wait(b"Foo")
 
-            self.assertEqual(string_file.getvalue(), "Foo")
+            self.assertEqual(string_file.getvalue(), b"Foo")
 
-            yield producer.write_and_wait("Bar")
+            yield producer.write_and_wait(b"Bar")
 
-            self.assertEqual(string_file.getvalue(), "FooBar")
+            self.assertEqual(string_file.getvalue(), b"FooBar")
         finally:
             consumer.unregisterProducer()
 
-        yield consumer.wait()
+        yield consumer.wait()  # type: ignore[misc]
 
         self.assertTrue(string_file.closed)
 
     @defer.inlineCallbacks
-    def test_push_consumer(self):
-        string_file = BlockingStringWrite()
-        consumer = BackgroundFileConsumer(string_file, reactor=reactor)
+    def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
+        string_file = BlockingBytesWrite()
+        consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
 
         try:
             producer = NonCallableMock(spec_set=[])
 
             consumer.registerProducer(producer, True)
 
-            consumer.write("Foo")
-            yield string_file.wait_for_n_writes(1)
+            consumer.write(b"Foo")
+            yield string_file.wait_for_n_writes(1)  # type: ignore[misc]
 
-            self.assertEqual(string_file.buffer, "Foo")
+            self.assertEqual(string_file.buffer, b"Foo")
 
-            consumer.write("Bar")
-            yield string_file.wait_for_n_writes(2)
+            consumer.write(b"Bar")
+            yield string_file.wait_for_n_writes(2)  # type: ignore[misc]
 
-            self.assertEqual(string_file.buffer, "FooBar")
+            self.assertEqual(string_file.buffer, b"FooBar")
         finally:
             consumer.unregisterProducer()
 
-        yield consumer.wait()
+        yield consumer.wait()  # type: ignore[misc]
 
         self.assertTrue(string_file.closed)
 
     @defer.inlineCallbacks
-    def test_push_producer_feedback(self):
-        string_file = BlockingStringWrite()
-        consumer = BackgroundFileConsumer(string_file, reactor=reactor)
+    def test_push_producer_feedback(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        string_file = BlockingBytesWrite()
+        consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
 
         try:
             producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
 
-            resume_deferred = defer.Deferred()
+            resume_deferred: defer.Deferred = defer.Deferred()
             producer.resumeProducing.side_effect = lambda: resume_deferred.callback(
                 None
             )
@@ -93,65 +101,72 @@ class FileConsumerTests(unittest.TestCase):
             number_writes = 0
             with string_file.write_lock:
                 for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
-                    consumer.write("Foo")
+                    consumer.write(b"Foo")
                     number_writes += 1
 
                 producer.pauseProducing.assert_called_once()
 
-            yield string_file.wait_for_n_writes(number_writes)
+            yield string_file.wait_for_n_writes(number_writes)  # type: ignore[misc]
 
             yield resume_deferred
             producer.resumeProducing.assert_called_once()
         finally:
             consumer.unregisterProducer()
 
-        yield consumer.wait()
+        yield consumer.wait()  # type: ignore[misc]
 
         self.assertTrue(string_file.closed)
 
 
+@implementer(IPullProducer)
 class DummyPullProducer:
-    def __init__(self):
-        self.consumer = None
-        self.deferred = defer.Deferred()
+    def __init__(self) -> None:
+        self.consumer: Optional[BackgroundFileConsumer] = None
+        self.deferred: "defer.Deferred[object]" = defer.Deferred()
 
-    def resumeProducing(self):
+    def resumeProducing(self) -> None:
         d = self.deferred
         self.deferred = defer.Deferred()
         d.callback(None)
 
-    def write_and_wait(self, bytes):
+    def stopProducing(self) -> None:
+        raise RuntimeError("Unexpected call")
+
+    def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]":
+        assert self.consumer is not None
         d = self.deferred
-        self.consumer.write(bytes)
+        self.consumer.write(write_bytes)
         return d
 
-    def register_with_consumer(self, consumer):
+    def register_with_consumer(
+        self, consumer: BackgroundFileConsumer
+    ) -> "defer.Deferred[object]":
         d = self.deferred
         self.consumer = consumer
         self.consumer.registerProducer(self, False)
         return d
 
 
-class BlockingStringWrite:
-    def __init__(self):
-        self.buffer = ""
+class BlockingBytesWrite:
+    def __init__(self) -> None:
+        self.buffer = b""
         self.closed = False
         self.write_lock = threading.Lock()
 
-        self._notify_write_deferred = None
+        self._notify_write_deferred: Optional[defer.Deferred] = None
         self._number_of_writes = 0
 
-    def write(self, bytes):
+    def write(self, write_bytes: bytes) -> None:
         with self.write_lock:
-            self.buffer += bytes
+            self.buffer += write_bytes
             self._number_of_writes += 1
 
         reactor.callFromThread(self._notify_write)
 
-    def close(self):
+    def close(self) -> None:
         self.closed = True
 
-    def _notify_write(self):
+    def _notify_write(self) -> None:
         "Called by write to indicate a write happened"
         with self.write_lock:
             if not self._notify_write_deferred:
@@ -161,7 +176,9 @@ class BlockingStringWrite:
         d.callback(None)
 
     @defer.inlineCallbacks
-    def wait_for_n_writes(self, n):
+    def wait_for_n_writes(
+        self, n: int
+    ) -> Generator["defer.Deferred[object]", object, None]:
         "Wait for n writes to have happened"
         while True:
             with self.write_lock: