diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index c577142268..fbafffd69b 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -84,7 +84,11 @@ class HTTPConnectProxyEndpoint:
def __repr__(self):
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
- def connect(self, protocolFactory: ClientFactory):
+ # Mypy encounters a false positive here: it complains that ClientFactory
+ # is incompatible with IProtocolFactory. But ClientFactory inherits from
+ # Factory, which implements IProtocolFactory. So I think this is a bug
+ # in mypy-zope.
+ def connect(self, protocolFactory: ClientFactory): # type: ignore[override]
f = HTTPProxiedClientFactory(
self._host, self._port, protocolFactory, self._proxy_creds
)
@@ -119,13 +123,15 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
self.dst_port = dst_port
self.wrapped_factory = wrapped_factory
self.proxy_creds = proxy_creds
- self.on_connection = defer.Deferred()
+ self.on_connection: "defer.Deferred[None]" = defer.Deferred()
def startedConnecting(self, connector):
return self.wrapped_factory.startedConnecting(connector)
def buildProtocol(self, addr):
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
+ if wrapped_protocol is None:
+ raise TypeError("buildProtocol produced None instead of a Protocol")
return HTTPConnectProtocol(
self.dst_host,
@@ -235,7 +241,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
self.host = host
self.port = port
self.proxy_creds = proxy_creds
- self.on_connected = defer.Deferred()
+ self.on_connected: "defer.Deferred[None]" = defer.Deferred()
def connectionMade(self):
logger.debug("Connected to proxy, sending CONNECT")
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 602f93c497..4886626d50 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -15,6 +15,8 @@
import logging
import threading
+import traceback
+from typing import Dict, Mapping, Set, Tuple
from prometheus_client.core import Counter, Histogram
@@ -105,19 +107,14 @@ in_flight_requests_db_sched_duration = Counter(
["method", "servlet"],
)
-# The set of all in flight requests, set[RequestMetrics]
-_in_flight_requests = set()
+_in_flight_requests: Set["RequestMetrics"] = set()
# Protects the _in_flight_requests set from concurrent access
_in_flight_requests_lock = threading.Lock()
-def _get_in_flight_counts():
- """Returns a count of all in flight requests by (method, server_name)
-
- Returns:
- dict[tuple[str, str], int]
- """
+def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]:
+ """Returns a count of all in flight requests by (method, server_name)"""
# Cast to a list to prevent it changing while the Prometheus
# thread is collecting metrics
with _in_flight_requests_lock:
@@ -127,8 +124,9 @@ def _get_in_flight_counts():
rm.update_metrics()
# Map from (method, name) -> int, the number of in flight requests of that
- # type
- counts = {}
+ # type. The key type is Tuple[str, str], but we leave the length unspecified
+ # for compatability with LaterGauge's annotations.
+ counts: Dict[Tuple[str, ...], int] = {}
for rm in reqs:
key = (rm.method, rm.name)
counts[key] = counts.get(key, 0) + 1
@@ -145,15 +143,21 @@ LaterGauge(
class RequestMetrics:
- def start(self, time_sec, name, method):
- self.start = time_sec
+ def start(self, time_sec: float, name: str, method: str) -> None:
+ self.start_ts = time_sec
self.start_context = current_context()
self.name = name
self.method = method
- # _request_stats records resource usage that we have already added
- # to the "in flight" metrics.
- self._request_stats = self.start_context.get_resource_usage()
+ if self.start_context:
+ # _request_stats records resource usage that we have already added
+ # to the "in flight" metrics.
+ self._request_stats = self.start_context.get_resource_usage()
+ else:
+ logger.error(
+ "Tried to start a RequestMetric from the sentinel context.\n%s",
+ "".join(traceback.format_stack()),
+ )
with _in_flight_requests_lock:
_in_flight_requests.add(self)
@@ -169,12 +173,18 @@ class RequestMetrics:
tag = context.tag
if context != self.start_context:
- logger.warning(
+ logger.error(
"Context have unexpectedly changed %r, %r",
context,
self.start_context,
)
return
+ else:
+ logger.error(
+ "Trying to stop RequestMetrics in the sentinel context.\n%s",
+ "".join(traceback.format_stack()),
+ )
+ return
response_code = str(response_code)
@@ -183,7 +193,7 @@ class RequestMetrics:
response_count.labels(self.method, self.name, tag).inc()
response_timer.labels(self.method, self.name, tag, response_code).observe(
- time_sec - self.start
+ time_sec - self.start_ts
)
resource_usage = context.get_resource_usage()
@@ -213,6 +223,12 @@ class RequestMetrics:
def update_metrics(self):
"""Updates the in flight metrics with values from this request."""
+ if not self.start_context:
+ logger.error(
+ "Tried to update a RequestMetric from the sentinel context.\n%s",
+ "".join(traceback.format_stack()),
+ )
+ return
new_stats = self.start_context.get_resource_usage()
diff = new_stats - self._request_stats
|