diff --git a/changelog.d/9543.misc b/changelog.d/9543.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9543.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index e56cf846f5..999aecce5c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -21,8 +21,10 @@ import threading
from string import Template
import yaml
+from zope.interface import implementer
from twisted.logger import (
+ ILogObserver,
LogBeginner,
STDLibLogObserver,
eventAsText,
@@ -227,7 +229,8 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
threadlocal = threading.local()
- def _log(event):
+ @implementer(ILogObserver)
+ def _log(event: dict) -> None:
if "log_text" in event:
if event["log_text"].startswith("DNSDatagramProtocol starting on "):
return
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 7657697bfa..ffc735ba25 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -361,7 +361,7 @@ class FederationServer(FederationBase):
logger.error(
"Failed to handle PDU %s",
event_id,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
await concurrently_execute(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 059064a4eb..66dc886c81 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -285,7 +285,7 @@ class PaginationHandler:
except Exception:
f = Failure()
logger.error(
- "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
+ "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore
)
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 4def7d7633..ecd63e6596 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -322,7 +322,8 @@ def _cache_period_from_headers(
def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
cache_controls = {}
- for hdr in headers.getRawHeaders(b"cache-control", []):
+ cache_control_headers = headers.getRawHeaders(b"cache-control") or []
+ for hdr in cache_control_headers:
for directive in hdr.split(b","):
splits = [x.strip() for x in directive.split(b"=", 1)]
k = splits[0].lower()
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 78e27bfb00..1a7ea4fa96 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -669,7 +669,7 @@ def preserve_fn(f):
return g
-def run_in_background(f, *args, **kwargs):
+def run_in_background(f, *args, **kwargs) -> defer.Deferred:
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@@ -697,8 +697,10 @@ def run_in_background(f, *args, **kwargs):
if isinstance(res, types.CoroutineType):
res = defer.ensureDeferred(res)
+ # At this point we should have a Deferred, if not then f was a synchronous
+ # function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
- return res
+ return defer.succeed(res)
if res.called and not res.paused:
# The function should have maintained the logcontext, so we can
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index f6a6aed35e..20940c8107 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -22,6 +22,7 @@ from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
+from twisted.web.server import Request, Site
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -32,7 +33,10 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.resource import (
+ ReplicationStreamProtocolFactory,
+ ServerReplicationStreamProtocol,
+)
from synapse.server import HomeServer
from synapse.util import Clock
@@ -59,7 +63,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
- self.server = server_factory.buildProtocol(None)
+ self.server = server_factory.buildProtocol(
+ None
+ ) # type: ServerReplicationStreamProtocol
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -155,9 +161,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
request_factory = OneShotRequestFactory()
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor)
- channel.requestFactory = request_factory
- channel.site = self.site
+ channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -188,8 +192,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
fetching updates for given stream.
"""
+ path = request.path # type: bytes # type: ignore
self.assertRegex(
- request.path,
+ path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
% (stream_name.encode("ascii"),),
)
@@ -390,9 +395,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
request_factory = OneShotRequestFactory()
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor)
- channel.requestFactory = request_factory
- channel.site = self._hs_to_site[hs]
+ channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -475,9 +478,13 @@ class _PushHTTPChannel(HTTPChannel):
makes it very hard to test.
"""
- def __init__(self, reactor: IReactorTime):
+ def __init__(
+ self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
+ ):
super().__init__()
self.reactor = reactor
+ self.requestFactory = request_factory
+ self.site = site
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
diff --git a/tests/server.py b/tests/server.py
index 939a0008ca..863f6da738 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -188,7 +188,7 @@ class FakeSite:
def make_request(
reactor,
- site: Site,
+ site: Union[Site, FakeSite],
method,
path,
content=b"",
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 52ae5c5713..74568b34f8 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -28,7 +28,7 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
- self.tx_log.emit(
+ self.tx_log.emit( # type: ignore
twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)
|