summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9543.misc1
-rw-r--r--synapse/config/logger.py5
-rw-r--r--synapse/federation/federation_server.py2
-rw-r--r--synapse/handlers/pagination.py2
-rw-r--r--synapse/http/federation/well_known_resolver.py3
-rw-r--r--synapse/logging/context.py6
-rw-r--r--tests/replication/_base.py27
-rw-r--r--tests/server.py2
-rw-r--r--tests/test_utils/logging_setup.py2
9 files changed, 32 insertions, 18 deletions
diff --git a/changelog.d/9543.misc b/changelog.d/9543.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9543.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index e56cf846f5..999aecce5c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -21,8 +21,10 @@ import threading
 from string import Template
 
 import yaml
+from zope.interface import implementer
 
 from twisted.logger import (
+    ILogObserver,
     LogBeginner,
     STDLibLogObserver,
     eventAsText,
@@ -227,7 +229,8 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
 
     threadlocal = threading.local()
 
-    def _log(event):
+    @implementer(ILogObserver)
+    def _log(event: dict) -> None:
         if "log_text" in event:
             if event["log_text"].startswith("DNSDatagramProtocol starting on "):
                 return
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 7657697bfa..ffc735ba25 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -361,7 +361,7 @@ class FederationServer(FederationBase):
                             logger.error(
                                 "Failed to handle PDU %s",
                                 event_id,
-                                exc_info=(f.type, f.value, f.getTracebackObject()),
+                                exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
                             )
 
         await concurrently_execute(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 059064a4eb..66dc886c81 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -285,7 +285,7 @@ class PaginationHandler:
         except Exception:
             f = Failure()
             logger.error(
-                "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
+                "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())  # type: ignore
             )
             self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
         finally:
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 4def7d7633..ecd63e6596 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -322,7 +322,8 @@ def _cache_period_from_headers(
 
 def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
     cache_controls = {}
-    for hdr in headers.getRawHeaders(b"cache-control", []):
+    cache_control_headers = headers.getRawHeaders(b"cache-control") or []
+    for hdr in cache_control_headers:
         for directive in hdr.split(b","):
             splits = [x.strip() for x in directive.split(b"=", 1)]
             k = splits[0].lower()
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 78e27bfb00..1a7ea4fa96 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -669,7 +669,7 @@ def preserve_fn(f):
     return g
 
 
-def run_in_background(f, *args, **kwargs):
+def run_in_background(f, *args, **kwargs) -> defer.Deferred:
     """Calls a function, ensuring that the current context is restored after
     return from the function, and that the sentinel context is set once the
     deferred returned by the function completes.
@@ -697,8 +697,10 @@ def run_in_background(f, *args, **kwargs):
     if isinstance(res, types.CoroutineType):
         res = defer.ensureDeferred(res)
 
+    # At this point we should have a Deferred, if not then f was a synchronous
+    # function, wrap it in a Deferred for consistency.
     if not isinstance(res, defer.Deferred):
-        return res
+        return defer.succeed(res)
 
     if res.called and not res.paused:
         # The function should have maintained the logcontext, so we can
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index f6a6aed35e..20940c8107 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -22,6 +22,7 @@ from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 from twisted.web.resource import Resource
+from twisted.web.server import Request, Site
 
 from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
@@ -32,7 +33,10 @@ from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.resource import (
+    ReplicationStreamProtocolFactory,
+    ServerReplicationStreamProtocol,
+)
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -59,7 +63,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
-        self.server = server_factory.buildProtocol(None)
+        self.server = server_factory.buildProtocol(
+            None
+        )  # type: ServerReplicationStreamProtocol
 
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -155,9 +161,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         request_factory = OneShotRequestFactory()
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor)
-        channel.requestFactory = request_factory
-        channel.site = self.site
+        channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -188,8 +192,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         fetching updates for given stream.
         """
 
+        path = request.path  # type: bytes  # type: ignore
         self.assertRegex(
-            request.path,
+            path,
             br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
             % (stream_name.encode("ascii"),),
         )
@@ -390,9 +395,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         request_factory = OneShotRequestFactory()
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor)
-        channel.requestFactory = request_factory
-        channel.site = self._hs_to_site[hs]
+        channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -475,9 +478,13 @@ class _PushHTTPChannel(HTTPChannel):
     makes it very hard to test.
     """
 
-    def __init__(self, reactor: IReactorTime):
+    def __init__(
+        self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
+    ):
         super().__init__()
         self.reactor = reactor
+        self.requestFactory = request_factory
+        self.site = site
 
         self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]
 
diff --git a/tests/server.py b/tests/server.py
index 939a0008ca..863f6da738 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -188,7 +188,7 @@ class FakeSite:
 
 def make_request(
     reactor,
-    site: Site,
+    site: Union[Site, FakeSite],
     method,
     path,
     content=b"",
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 52ae5c5713..74568b34f8 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -28,7 +28,7 @@ class ToTwistedHandler(logging.Handler):
     def emit(self, record):
         log_entry = self.format(record)
         log_level = record.levelname.lower().replace("warning", "warn")
-        self.tx_log.emit(
+        self.tx_log.emit(  # type: ignore
             twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
         )