diff --git a/changelog.d/9608.misc b/changelog.d/9608.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9608.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 34787e0b1e..080ca40287 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -19,7 +19,7 @@ from typing import Any, List, Optional, Type, Union
from twisted.internet import protocol
-class RedisProtocol:
+class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
async def set(
diff --git a/synapse/http/client.py b/synapse/http/client.py
index d4ab3a2732..1e01e0a9f2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -45,7 +45,9 @@ from twisted.internet.interfaces import (
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
+ ITCPTransport,
)
+from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@@ -760,6 +762,8 @@ class BodyExceededMaxSize(Exception):
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""
+ transport = None # type: Optional[ITCPTransport]
+
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
@@ -771,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
+ assert self.transport is not None
self.transport.abortConnection()
def dataReceived(self, data: bytes) -> None:
self._maybe_fail()
- def connectionLost(self, reason: Failure) -> None:
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
self._maybe_fail()
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
+ transport = None # type: Optional[ITCPTransport]
+
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@@ -805,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
+ assert self.transport is not None
self.transport.abortConnection()
- def connectionLost(self, reason: Failure) -> None:
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index ee909f3fc5..a8894beadf 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -302,7 +302,7 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
- hs.config.redis.redis_host,
+ hs.config.redis.redis_host.encode(),
hs.config.redis.redis_port,
self._factory,
)
@@ -311,7 +311,7 @@ class ReplicationCommandHandler:
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
- hs.get_reactor().connectTCP(host, port, self._factory)
+ hs.get_reactor().connectTCP(host.encode(), port, self._factory)
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 8e4734b59c..825900f64c 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -56,6 +56,7 @@ from prometheus_client import Counter
from zope.interface import Interface, implementer
from twisted.internet import task
+from twisted.internet.tcp import Connection
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@@ -145,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
(if they send a `PING` command)
"""
+ # The transport is going to be an ITCPTransport, but that doesn't have the
+ # (un)registerProducer methods, those are only on the implementation.
+ transport = None # type: Connection
+
delimiter = b"\n"
# Valid commands we expect to receive
@@ -189,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
connected_connections.append(self) # Register connection for metrics
+ assert self.transport is not None
self.transport.registerProducer(self, True) # For the *Producing callbacks
self._send_pending_commands()
@@ -213,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info(
"[%s] Failed to close connection gracefully, aborting", self.id()
)
+ assert self.transport is not None
self.transport.abortConnection()
else:
if now - self.last_sent_command >= PING_TIME:
@@ -302,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def close(self):
logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
+ assert self.transport is not None
self.transport.loseConnection()
self.on_connection_closed()
@@ -399,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
+ assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc()
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 7cccde097d..2f4d407f94 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -365,6 +365,6 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
- reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
+ reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
return factory.handler
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 20940c8107..67b7913666 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
-import attr
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
@@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
- request_factory = OneShotRequestFactory()
-
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
+ channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
- return request_factory.request
+ return channel.request
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
- "localhost",
+ b"localhost",
6379,
self.connect_any_redis_attempts,
)
@@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
- request_factory = OneShotRequestFactory()
-
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
+ channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, "localhost")
+ self.assertEqual(host, b"localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
@@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
self.received_rdata_rows.append((stream_name, token, r))
-@attr.s()
-class OneShotRequestFactory:
- """A simple request factory that generates a single `SynapseRequest` and
- stores it for future use. Can only be used once.
- """
-
- request = attr.ib(default=None)
-
- def __call__(self, *args, **kwargs):
- assert self.request is None
-
- self.request = SynapseRequest(*args, **kwargs)
- return self.request
-
-
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
@@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel):
"""
def __init__(
- self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
+ self, reactor: IReactorTime, request_factory: Type[Request], site: Site
):
super().__init__()
self.reactor = reactor
@@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False
+ def requestDone(self, request):
+ # Store the request for inspection.
+ self.request = request
+ super().requestDone(request)
+
class _PullToPushProducer:
"""A push producer that wraps a pull producer."""
@@ -597,6 +581,8 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""
+ transport = None # type: Optional[FakeTransport]
+
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()
@@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):
def send(self, msg):
"""Send a message back to the client."""
+ assert self.transport is not None
+
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)
diff --git a/tests/server.py b/tests/server.py
index 863f6da738..2287d20076 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,6 +16,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IReactorTCP,
IResolverSimple,
+ ITransport,
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
@@ -467,6 +468,7 @@ def get_clock():
return clock, hs_clock
+@implementer(ITransport)
@attr.s(cmp=False)
class FakeTransport:
"""
|