summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/http/client.py12
-rw-r--r--synapse/replication/tcp/handler.py4
-rw-r--r--synapse/replication/tcp/protocol.py9
-rw-r--r--synapse/replication/tcp/redis.py2
4 files changed, 22 insertions, 5 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index d4ab3a2732..1e01e0a9f2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -45,7 +45,9 @@ from twisted.internet.interfaces import (
     IHostResolution,
     IReactorPluggableNameResolver,
     IResolutionReceiver,
+    ITCPTransport,
 )
+from twisted.internet.protocol import connectionDone
 from twisted.internet.task import Cooperator
 from twisted.python.failure import Failure
 from twisted.web._newclient import ResponseDone
@@ -760,6 +762,8 @@ class BodyExceededMaxSize(Exception):
 class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
     """A protocol which immediately errors upon receiving data."""
 
+    transport = None  # type: Optional[ITCPTransport]
+
     def __init__(self, deferred: defer.Deferred):
         self.deferred = deferred
 
@@ -771,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
             self.deferred.errback(BodyExceededMaxSize())
             # Close the connection (forcefully) since all the data will get
             # discarded anyway.
+            assert self.transport is not None
             self.transport.abortConnection()
 
     def dataReceived(self, data: bytes) -> None:
         self._maybe_fail()
 
-    def connectionLost(self, reason: Failure) -> None:
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         self._maybe_fail()
 
 
 class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
     """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
 
+    transport = None  # type: Optional[ITCPTransport]
+
     def __init__(
         self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
     ):
@@ -805,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
             self.deferred.errback(BodyExceededMaxSize())
             # Close the connection (forcefully) since all the data will get
             # discarded anyway.
+            assert self.transport is not None
             self.transport.abortConnection()
 
-    def connectionLost(self, reason: Failure) -> None:
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         # If the maximum size was already exceeded, there's nothing to do.
         if self.deferred.called:
             return
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index ee909f3fc5..a8894beadf 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -302,7 +302,7 @@ class ReplicationCommandHandler:
                 hs, outbound_redis_connection
             )
             hs.get_reactor().connectTCP(
-                hs.config.redis.redis_host,
+                hs.config.redis.redis_host.encode(),
                 hs.config.redis.redis_port,
                 self._factory,
             )
@@ -311,7 +311,7 @@ class ReplicationCommandHandler:
             self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
             host = hs.config.worker_replication_host
             port = hs.config.worker_replication_port
-            hs.get_reactor().connectTCP(host, port, self._factory)
+            hs.get_reactor().connectTCP(host.encode(), port, self._factory)
 
     def get_streams(self) -> Dict[str, Stream]:
         """Get a map from stream name to all streams."""
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 8e4734b59c..825900f64c 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -56,6 +56,7 @@ from prometheus_client import Counter
 from zope.interface import Interface, implementer
 
 from twisted.internet import task
+from twisted.internet.tcp import Connection
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
@@ -145,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     (if they send a `PING` command)
     """
 
+    # The transport is going to be an ITCPTransport, but that doesn't have the
+    # (un)registerProducer methods, those are only on the implementation.
+    transport = None  # type: Connection
+
     delimiter = b"\n"
 
     # Valid commands we expect to receive
@@ -189,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         connected_connections.append(self)  # Register connection for metrics
 
+        assert self.transport is not None
         self.transport.registerProducer(self, True)  # For the *Producing callbacks
 
         self._send_pending_commands()
@@ -213,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 logger.info(
                     "[%s] Failed to close connection gracefully, aborting", self.id()
                 )
+                assert self.transport is not None
                 self.transport.abortConnection()
         else:
             if now - self.last_sent_command >= PING_TIME:
@@ -302,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     def close(self):
         logger.warning("[%s] Closing connection", self.id())
         self.time_we_closed = self.clock.time_msec()
+        assert self.transport is not None
         self.transport.loseConnection()
         self.on_connection_closed()
 
@@ -399,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     def connectionLost(self, reason):
         logger.info("[%s] Replication connection closed: %r", self.id(), reason)
         if isinstance(reason, Failure):
+            assert reason.type is not None
             connection_close_counter.labels(reason.type.__name__).inc()
         else:
             connection_close_counter.labels(reason.__class__.__name__).inc()
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 7cccde097d..2f4d407f94 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -365,6 +365,6 @@ def lazyConnection(
     factory.continueTrying = reconnect
 
     reactor = hs.get_reactor()
-    reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
+    reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
 
     return factory.handler