diff --git a/changelog.d/11164.misc b/changelog.d/11164.misc
new file mode 100644
index 0000000000..751da49183
--- /dev/null
+++ b/changelog.d/11164.misc
@@ -0,0 +1 @@
+Add type hints so that `synapse.http` passes `mypy` checks.
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index c5f44aea39..8f5386c179 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -16,6 +16,7 @@ no_implicit_optional = True
files =
scripts-dev/sign_json,
+ synapse/__init__.py,
synapse/api,
synapse/appservice,
synapse/config,
@@ -31,16 +32,7 @@ files =
synapse/federation,
synapse/groups,
synapse/handlers,
- synapse/http/additional_resource.py,
- synapse/http/client.py,
- synapse/http/federation/matrix_federation_agent.py,
- synapse/http/federation/srv_resolver.py,
- synapse/http/federation/well_known_resolver.py,
- synapse/http/matrixfederationclient.py,
- synapse/http/proxyagent.py,
- synapse/http/servlet.py,
- synapse/http/server.py,
- synapse/http/site.py,
+ synapse/http,
synapse/logging,
synapse/metrics,
synapse/module_api,
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index c577142268..fbafffd69b 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -84,7 +84,11 @@ class HTTPConnectProxyEndpoint:
def __repr__(self):
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
- def connect(self, protocolFactory: ClientFactory):
+ # Mypy encounters a false positive here: it complains that ClientFactory
+ # is incompatible with IProtocolFactory. But ClientFactory inherits from
+ # Factory, which implements IProtocolFactory. So I think this is a bug
+ # in mypy-zope.
+ def connect(self, protocolFactory: ClientFactory): # type: ignore[override]
f = HTTPProxiedClientFactory(
self._host, self._port, protocolFactory, self._proxy_creds
)
@@ -119,13 +123,15 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
self.dst_port = dst_port
self.wrapped_factory = wrapped_factory
self.proxy_creds = proxy_creds
- self.on_connection = defer.Deferred()
+ self.on_connection: "defer.Deferred[None]" = defer.Deferred()
def startedConnecting(self, connector):
return self.wrapped_factory.startedConnecting(connector)
def buildProtocol(self, addr):
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
+ if wrapped_protocol is None:
+ raise TypeError("buildProtocol produced None instead of a Protocol")
return HTTPConnectProtocol(
self.dst_host,
@@ -235,7 +241,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
self.host = host
self.port = port
self.proxy_creds = proxy_creds
- self.on_connected = defer.Deferred()
+ self.on_connected: "defer.Deferred[None]" = defer.Deferred()
def connectionMade(self):
logger.debug("Connected to proxy, sending CONNECT")
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 602f93c497..4886626d50 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -15,6 +15,8 @@
import logging
import threading
+import traceback
+from typing import Dict, Mapping, Set, Tuple
from prometheus_client.core import Counter, Histogram
@@ -105,19 +107,14 @@ in_flight_requests_db_sched_duration = Counter(
["method", "servlet"],
)
-# The set of all in flight requests, set[RequestMetrics]
-_in_flight_requests = set()
+_in_flight_requests: Set["RequestMetrics"] = set()
# Protects the _in_flight_requests set from concurrent access
_in_flight_requests_lock = threading.Lock()
-def _get_in_flight_counts():
- """Returns a count of all in flight requests by (method, server_name)
-
- Returns:
- dict[tuple[str, str], int]
- """
+def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]:
+ """Returns a count of all in flight requests by (method, server_name)"""
# Cast to a list to prevent it changing while the Prometheus
# thread is collecting metrics
with _in_flight_requests_lock:
@@ -127,8 +124,9 @@ def _get_in_flight_counts():
rm.update_metrics()
# Map from (method, name) -> int, the number of in flight requests of that
- # type
- counts = {}
+ # type. The key type is Tuple[str, str], but we leave the length unspecified
+ # for compatability with LaterGauge's annotations.
+ counts: Dict[Tuple[str, ...], int] = {}
for rm in reqs:
key = (rm.method, rm.name)
counts[key] = counts.get(key, 0) + 1
@@ -145,15 +143,21 @@ LaterGauge(
class RequestMetrics:
- def start(self, time_sec, name, method):
- self.start = time_sec
+ def start(self, time_sec: float, name: str, method: str) -> None:
+ self.start_ts = time_sec
self.start_context = current_context()
self.name = name
self.method = method
- # _request_stats records resource usage that we have already added
- # to the "in flight" metrics.
- self._request_stats = self.start_context.get_resource_usage()
+ if self.start_context:
+ # _request_stats records resource usage that we have already added
+ # to the "in flight" metrics.
+ self._request_stats = self.start_context.get_resource_usage()
+ else:
+ logger.error(
+ "Tried to start a RequestMetric from the sentinel context.\n%s",
+ "".join(traceback.format_stack()),
+ )
with _in_flight_requests_lock:
_in_flight_requests.add(self)
@@ -169,12 +173,18 @@ class RequestMetrics:
tag = context.tag
if context != self.start_context:
- logger.warning(
+ logger.error(
"Context have unexpectedly changed %r, %r",
context,
self.start_context,
)
return
+ else:
+ logger.error(
+ "Trying to stop RequestMetrics in the sentinel context.\n%s",
+ "".join(traceback.format_stack()),
+ )
+ return
response_code = str(response_code)
@@ -183,7 +193,7 @@ class RequestMetrics:
response_count.labels(self.method, self.name, tag).inc()
response_timer.labels(self.method, self.name, tag, response_code).observe(
- time_sec - self.start
+ time_sec - self.start_ts
)
resource_usage = context.get_resource_usage()
@@ -213,6 +223,12 @@ class RequestMetrics:
def update_metrics(self):
"""Updates the in flight metrics with values from this request."""
+ if not self.start_context:
+ logger.error(
+ "Tried to update a RequestMetric from the sentinel context.\n%s",
+ "".join(traceback.format_stack()),
+ )
+ return
new_stats = self.start_context.get_resource_usage()
diff = new_stats - self._request_stats
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index bdc0187743..d8ae3188b7 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -220,7 +220,7 @@ class _Sentinel:
self.scope = None
self.tag = None
- def __str__(self):
+ def __str__(self) -> str:
return "sentinel"
def copy_to(self, record):
@@ -241,7 +241,7 @@ class _Sentinel:
def record_event_fetch(self, event_count):
pass
- def __bool__(self):
+ def __bool__(self) -> Literal[False]:
return False
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index f237b8a236..e902109af3 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -20,7 +20,7 @@ import os
import platform
import threading
import time
-from typing import Callable, Dict, Iterable, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
import attr
from prometheus_client import Counter, Gauge, Histogram
@@ -67,7 +67,11 @@ class LaterGauge:
labels = attr.ib(hash=False, type=Optional[Iterable[str]])
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
- caller = attr.ib(type=Callable[[], Union[Dict[Tuple[str, ...], float], float]])
+ caller = attr.ib(
+ type=Callable[
+ [], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
+ ]
+ )
def collect(self):
@@ -80,11 +84,11 @@ class LaterGauge:
yield g
return
- if isinstance(calls, dict):
+ if isinstance(calls, (int, float)):
+ g.add_metric([], calls)
+ else:
for k, v in calls.items():
g.add_metric(k, v)
- else:
- g.add_metric([], calls)
yield g
|