summary refs log tree commit diff
path: root/synapse/logging/_terse_json.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/logging/_terse_json.py')
-rw-r--r--synapse/logging/_terse_json.py107
1 files changed, 78 insertions, 29 deletions
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 76ce7d8808..03934956f4 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -17,25 +17,29 @@
 Log formatters that output terse JSON.
 """
 
+import json
 import sys
+import traceback
 from collections import deque
 from ipaddress import IPv4Address, IPv6Address, ip_address
 from math import floor
-from typing import IO
+from typing import IO, Optional
 
 import attr
-from simplejson import dumps
 from zope.interface import implementer
 
 from twisted.application.internet import ClientService
+from twisted.internet.defer import Deferred
 from twisted.internet.endpoints import (
     HostnameEndpoint,
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
+from twisted.internet.interfaces import IPushProducer, ITransport
 from twisted.internet.protocol import Factory, Protocol
 from twisted.logger import FileLogObserver, ILogObserver, Logger
-from twisted.python.failure import Failure
+
+_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
 
 
 def flatten_event(event: dict, metadata: dict, include_time: bool = False):
@@ -141,12 +145,50 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
 
     def formatEvent(_event: dict) -> str:
         flattened = flatten_event(_event, metadata)
-        return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n"
+        return _encoder.encode(flattened) + "\n"
 
     return FileLogObserver(outFile, formatEvent)
 
 
 @attr.s
+@implementer(IPushProducer)
+class LogProducer(object):
+    """
+    An IPushProducer that writes logs from its buffer to its transport when it
+    is resumed.
+
+    Args:
+        buffer: Log buffer to read logs from.
+        transport: Transport to write to.
+    """
+
+    transport = attr.ib(type=ITransport)
+    _buffer = attr.ib(type=deque)
+    _paused = attr.ib(default=False, type=bool, init=False)
+
+    def pauseProducing(self):
+        self._paused = True
+
+    def stopProducing(self):
+        self._paused = True
+        self._buffer = None
+
+    def resumeProducing(self):
+        self._paused = False
+
+        while self._paused is False and (self._buffer and self.transport.connected):
+            try:
+                event = self._buffer.popleft()
+                self.transport.write(_encoder.encode(event).encode("utf8"))
+                self.transport.write(b"\n")
+            except Exception:
+                # Something has gone wrong writing to the transport -- log it
+                # and break out of the while.
+                traceback.print_exc(file=sys.__stderr__)
+                break
+
+
+@attr.s
 @implementer(ILogObserver)
 class TerseJSONToTCPLogObserver(object):
     """
@@ -165,8 +207,9 @@ class TerseJSONToTCPLogObserver(object):
     metadata = attr.ib(type=dict)
     maximum_buffer = attr.ib(type=int)
     _buffer = attr.ib(default=attr.Factory(deque), type=deque)
-    _writer = attr.ib(default=None)
+    _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
     _logger = attr.ib(default=attr.Factory(Logger))
+    _producer = attr.ib(default=None, type=Optional[LogProducer])
 
     def start(self) -> None:
 
@@ -187,38 +230,44 @@ class TerseJSONToTCPLogObserver(object):
         factory = Factory.forProtocol(Protocol)
         self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
         self._service.startService()
+        self._connect()
 
-    def _write_loop(self) -> None:
+    def stop(self):
+        self._service.stopService()
+
+    def _connect(self) -> None:
         """
-        Implement the write loop.
+        Triggers an attempt to connect then write to the remote if not already writing.
         """
-        if self._writer:
+        if self._connection_waiter:
             return
 
-        self._writer = self._service.whenConnected()
+        self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
+
+        @self._connection_waiter.addErrback
+        def fail(r):
+            r.printTraceback(file=sys.__stderr__)
+            self._connection_waiter = None
+            self._connect()
 
-        @self._writer.addBoth
+        @self._connection_waiter.addCallback
         def writer(r):
-            if isinstance(r, Failure):
-                r.printTraceback(file=sys.__stderr__)
-                self._writer = None
-                self.hs.get_reactor().callLater(1, self._write_loop)
+            # We have a connection. If we already have a producer, and its
+            # transport is the same, just trigger a resumeProducing.
+            if self._producer and r.transport is self._producer.transport:
+                self._producer.resumeProducing()
+                self._connection_waiter = None
                 return
 
-            try:
-                for event in self._buffer:
-                    r.transport.write(
-                        dumps(event, ensure_ascii=False, separators=(",", ":")).encode(
-                            "utf8"
-                        )
-                    )
-                    r.transport.write(b"\n")
-                self._buffer.clear()
-            except Exception as e:
-                sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),))
-
-            self._writer = False
-            self.hs.get_reactor().callLater(1, self._write_loop)
+            # If the producer is still producing, stop it.
+            if self._producer:
+                self._producer.stopProducing()
+
+            # Make a new producer and start it.
+            self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
+            r.transport.registerProducer(self._producer, True)
+            self._producer.resumeProducing()
+            self._connection_waiter = None
 
     def _handle_pressure(self) -> None:
         """
@@ -277,4 +326,4 @@ class TerseJSONToTCPLogObserver(object):
             self._logger.failure("Failed clearing backpressure")
 
         # Try and write immediately.
-        self._write_loop()
+        self._connect()