summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/http/connectproxyclient.py12
-rw-r--r--synapse/http/request_metrics.py50
-rw-r--r--synapse/logging/context.py4
-rw-r--r--synapse/metrics/__init__.py14
4 files changed, 53 insertions, 27 deletions
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index c577142268..fbafffd69b 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -84,7 +84,11 @@ class HTTPConnectProxyEndpoint:
     def __repr__(self):
         return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
 
-    def connect(self, protocolFactory: ClientFactory):
+    # Mypy encounters a false positive here: it complains that ClientFactory
+    # is incompatible with IProtocolFactory. But ClientFactory inherits from
+    # Factory, which implements IProtocolFactory. So I think this is a bug
+    # in mypy-zope.
+    def connect(self, protocolFactory: ClientFactory):  # type: ignore[override]
         f = HTTPProxiedClientFactory(
             self._host, self._port, protocolFactory, self._proxy_creds
         )
@@ -119,13 +123,15 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
         self.dst_port = dst_port
         self.wrapped_factory = wrapped_factory
         self.proxy_creds = proxy_creds
-        self.on_connection = defer.Deferred()
+        self.on_connection: "defer.Deferred[None]" = defer.Deferred()
 
     def startedConnecting(self, connector):
         return self.wrapped_factory.startedConnecting(connector)
 
     def buildProtocol(self, addr):
         wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
+        if wrapped_protocol is None:
+            raise TypeError("buildProtocol produced None instead of a Protocol")
 
         return HTTPConnectProtocol(
             self.dst_host,
@@ -235,7 +241,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
         self.host = host
         self.port = port
         self.proxy_creds = proxy_creds
-        self.on_connected = defer.Deferred()
+        self.on_connected: "defer.Deferred[None]" = defer.Deferred()
 
     def connectionMade(self):
         logger.debug("Connected to proxy, sending CONNECT")
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 602f93c497..4886626d50 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -15,6 +15,8 @@
 
 import logging
 import threading
+import traceback
+from typing import Dict, Mapping, Set, Tuple
 
 from prometheus_client.core import Counter, Histogram
 
@@ -105,19 +107,14 @@ in_flight_requests_db_sched_duration = Counter(
     ["method", "servlet"],
 )
 
-# The set of all in flight requests, set[RequestMetrics]
-_in_flight_requests = set()
+_in_flight_requests: Set["RequestMetrics"] = set()
 
 # Protects the _in_flight_requests set from concurrent access
 _in_flight_requests_lock = threading.Lock()
 
 
-def _get_in_flight_counts():
-    """Returns a count of all in flight requests by (method, server_name)
-
-    Returns:
-        dict[tuple[str, str], int]
-    """
+def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]:
+    """Returns a count of all in flight requests by (method, server_name)"""
     # Cast to a list to prevent it changing while the Prometheus
     # thread is collecting metrics
     with _in_flight_requests_lock:
@@ -127,8 +124,9 @@ def _get_in_flight_counts():
         rm.update_metrics()
 
     # Map from (method, name) -> int, the number of in flight requests of that
-    # type
-    counts = {}
+    # type. The key type is Tuple[str, str], but we leave the length unspecified
+    # for compatability with LaterGauge's annotations.
+    counts: Dict[Tuple[str, ...], int] = {}
     for rm in reqs:
         key = (rm.method, rm.name)
         counts[key] = counts.get(key, 0) + 1
@@ -145,15 +143,21 @@ LaterGauge(
 
 
 class RequestMetrics:
-    def start(self, time_sec, name, method):
-        self.start = time_sec
+    def start(self, time_sec: float, name: str, method: str) -> None:
+        self.start_ts = time_sec
         self.start_context = current_context()
         self.name = name
         self.method = method
 
-        # _request_stats records resource usage that we have already added
-        # to the "in flight" metrics.
-        self._request_stats = self.start_context.get_resource_usage()
+        if self.start_context:
+            # _request_stats records resource usage that we have already added
+            # to the "in flight" metrics.
+            self._request_stats = self.start_context.get_resource_usage()
+        else:
+            logger.error(
+                "Tried to start a RequestMetric from the sentinel context.\n%s",
+                "".join(traceback.format_stack()),
+            )
 
         with _in_flight_requests_lock:
             _in_flight_requests.add(self)
@@ -169,12 +173,18 @@ class RequestMetrics:
             tag = context.tag
 
             if context != self.start_context:
-                logger.warning(
+                logger.error(
                     "Context have unexpectedly changed %r, %r",
                     context,
                     self.start_context,
                 )
                 return
+        else:
+            logger.error(
+                "Trying to stop RequestMetrics in the sentinel context.\n%s",
+                "".join(traceback.format_stack()),
+            )
+            return
 
         response_code = str(response_code)
 
@@ -183,7 +193,7 @@ class RequestMetrics:
         response_count.labels(self.method, self.name, tag).inc()
 
         response_timer.labels(self.method, self.name, tag, response_code).observe(
-            time_sec - self.start
+            time_sec - self.start_ts
         )
 
         resource_usage = context.get_resource_usage()
@@ -213,6 +223,12 @@ class RequestMetrics:
 
     def update_metrics(self):
         """Updates the in flight metrics with values from this request."""
+        if not self.start_context:
+            logger.error(
+                "Tried to update a RequestMetric from the sentinel context.\n%s",
+                "".join(traceback.format_stack()),
+            )
+            return
         new_stats = self.start_context.get_resource_usage()
 
         diff = new_stats - self._request_stats
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index bdc0187743..d8ae3188b7 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -220,7 +220,7 @@ class _Sentinel:
         self.scope = None
         self.tag = None
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "sentinel"
 
     def copy_to(self, record):
@@ -241,7 +241,7 @@ class _Sentinel:
     def record_event_fetch(self, event_count):
         pass
 
-    def __bool__(self):
+    def __bool__(self) -> Literal[False]:
         return False
 
 
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index f237b8a236..e902109af3 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -20,7 +20,7 @@ import os
 import platform
 import threading
 import time
-from typing import Callable, Dict, Iterable, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
 
 import attr
 from prometheus_client import Counter, Gauge, Histogram
@@ -67,7 +67,11 @@ class LaterGauge:
     labels = attr.ib(hash=False, type=Optional[Iterable[str]])
     # callback: should either return a value (if there are no labels for this metric),
     # or dict mapping from a label tuple to a value
-    caller = attr.ib(type=Callable[[], Union[Dict[Tuple[str, ...], float], float]])
+    caller = attr.ib(
+        type=Callable[
+            [], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
+        ]
+    )
 
     def collect(self):
 
@@ -80,11 +84,11 @@ class LaterGauge:
             yield g
             return
 
-        if isinstance(calls, dict):
+        if isinstance(calls, (int, float)):
+            g.add_metric([], calls)
+        else:
             for k, v in calls.items():
                 g.add_metric(k, v)
-        else:
-            g.add_metric([], calls)
 
         yield g