summary refs log tree commit diff
path: root/synapse/util/file_consumer.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/file_consumer.py')
-rw-r--r--synapse/util/file_consumer.py48
1 files changed, 29 insertions, 19 deletions
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index e946189f9a..de2adacd70 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -13,10 +13,14 @@
 # limitations under the License.
 
 import queue
+from typing import BinaryIO, Optional, Union, cast
 
 from twisted.internet import threads
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IPullProducer, IPushProducer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import ISynapseReactor
 
 
 class BackgroundFileConsumer:
@@ -24,9 +28,9 @@ class BackgroundFileConsumer:
     and pull producers
 
     Args:
-        file_obj (file): The file like object to write to. Closed when
+        file_obj: The file like object to write to. Closed when
             finished.
-        reactor (twisted.internet.reactor): the Twisted reactor to use
+        reactor: the Twisted reactor to use
     """
 
     # For PushProducers pause if we have this many unwritten slices
@@ -34,13 +38,13 @@ class BackgroundFileConsumer:
     # And resume once the size of the queue is less than this
     _RESUME_ON_QUEUE_SIZE = 2
 
-    def __init__(self, file_obj, reactor):
-        self._file_obj = file_obj
+    def __init__(self, file_obj: BinaryIO, reactor: ISynapseReactor) -> None:
+        self._file_obj: BinaryIO = file_obj
 
-        self._reactor = reactor
+        self._reactor: ISynapseReactor = reactor
 
         # Producer we're registered with
-        self._producer = None
+        self._producer: Optional[Union[IPushProducer, IPullProducer]] = None
 
         # True if PushProducer, false if PullProducer
         self.streaming = False
@@ -51,20 +55,22 @@ class BackgroundFileConsumer:
 
         # Queue of slices of bytes to be written. When producer calls
         # unregister a final None is sent.
-        self._bytes_queue = queue.Queue()
+        self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue()
 
         # Deferred that is resolved when finished writing
-        self._finished_deferred = None
+        self._finished_deferred: Optional[Deferred[None]] = None
 
         # If the _writer thread throws an exception it gets stored here.
-        self._write_exception = None
+        self._write_exception: Optional[Exception] = None
 
-    def registerProducer(self, producer, streaming):
+    def registerProducer(
+        self, producer: Union[IPushProducer, IPullProducer], streaming: bool
+    ) -> None:
         """Part of IConsumer interface
 
         Args:
-            producer (IProducer)
-            streaming (bool): True if push based producer, False if pull
+            producer
+            streaming: True if push based producer, False if pull
                 based.
         """
         if self._producer:
@@ -81,29 +87,33 @@ class BackgroundFileConsumer:
         if not streaming:
             self._producer.resumeProducing()
 
-    def unregisterProducer(self):
+    def unregisterProducer(self) -> None:
         """Part of IProducer interface"""
         self._producer = None
+        assert self._finished_deferred is not None
         if not self._finished_deferred.called:
             self._bytes_queue.put_nowait(None)
 
-    def write(self, bytes):
+    def write(self, write_bytes: bytes) -> None:
         """Part of IProducer interface"""
         if self._write_exception:
             raise self._write_exception
 
+        assert self._finished_deferred is not None
         if self._finished_deferred.called:
             raise Exception("consumer has closed")
 
-        self._bytes_queue.put_nowait(bytes)
+        self._bytes_queue.put_nowait(write_bytes)
 
         # If this is a PushProducer and the queue is getting behind
         # then we pause the producer.
         if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
             self._paused_producer = True
-            self._producer.pauseProducing()
+            assert self._producer is not None
+            # cast safe because `streaming` means this is an IPushProducer
+            cast(IPushProducer, self._producer).pauseProducing()
 
-    def _writer(self):
+    def _writer(self) -> None:
         """This is run in a background thread to write to the file."""
         try:
             while self._producer or not self._bytes_queue.empty():
@@ -130,11 +140,11 @@ class BackgroundFileConsumer:
         finally:
             self._file_obj.close()
 
-    def wait(self):
+    def wait(self) -> "Deferred[None]":
         """Returns a deferred that resolves when finished writing to file"""
         return make_deferred_yieldable(self._finished_deferred)
 
-    def _resume_paused_producer(self):
+    def _resume_paused_producer(self) -> None:
         """Gets called if we should resume producing after being paused"""
         if self._paused_producer and self._producer:
             self._paused_producer = False