summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-07-21 08:01:52 -0400
committerGitHub <noreply@github.com>2022-07-21 12:01:52 +0000
commit50122754c8743df5c904e81b634fdfdeea64e795 (patch)
tree3c6b3e46aec5a66a05eaffb688154eb87b198fcb /synapse
parentUse cache store remove base slaved (#13329) (diff)
downloadsynapse-50122754c8743df5c904e81b634fdfdeea64e795.tar.xz
Add missing types to opentracing. (#13345)
After this change `synapse.logging` is fully typed.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/transport/server/_base.py2
-rw-r--r--synapse/handlers/device.py8
-rw-r--r--synapse/handlers/e2e_keys.py16
-rw-r--r--synapse/handlers/e2e_room_keys.py4
-rw-r--r--synapse/logging/opentracing.py44
-rw-r--r--synapse/metrics/background_process_metrics.py2
-rw-r--r--synapse/rest/client/keys.py4
-rw-r--r--synapse/storage/databases/main/deviceinbox.py2
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py6
10 files changed, 60 insertions, 32 deletions
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index 84100a5a52..bb0f8d6b7b 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -309,7 +309,7 @@ class BaseFederationServlet:
                 raise
 
             # update the active opentracing span with the authenticated entity
-            set_tag("authenticated_entity", origin)
+            set_tag("authenticated_entity", str(origin))
 
             # if the origin is authenticated and whitelisted, use its span context
             # as the parent.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index c05a170c55..1a8379854c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -118,8 +118,8 @@ class DeviceWorkerHandler:
         ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
         _update_device_from_client_ips(device, ips)
 
-        set_tag("device", device)
-        set_tag("ips", ips)
+        set_tag("device", str(device))
+        set_tag("ips", str(ips))
 
         return device
 
@@ -170,7 +170,7 @@ class DeviceWorkerHandler:
         """
 
         set_tag("user_id", user_id)
-        set_tag("from_token", from_token)
+        set_tag("from_token", str(from_token))
         now_room_key = self.store.get_room_max_token()
 
         room_ids = await self.store.get_rooms_for_user(user_id)
@@ -795,7 +795,7 @@ class DeviceListUpdater:
         """
 
         set_tag("origin", origin)
-        set_tag("edu_content", edu_content)
+        set_tag("edu_content", str(edu_content))
         user_id = edu_content.pop("user_id")
         device_id = edu_content.pop("device_id")
         stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 84c28c480e..c938339ddd 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -138,8 +138,8 @@ class E2eKeysHandler:
                 else:
                     remote_queries[user_id] = device_ids
 
-            set_tag("local_key_query", local_query)
-            set_tag("remote_key_query", remote_queries)
+            set_tag("local_key_query", str(local_query))
+            set_tag("remote_key_query", str(remote_queries))
 
             # First get local devices.
             # A map of destination -> failure response.
@@ -343,7 +343,7 @@ class E2eKeysHandler:
             failure = _exception_to_failure(e)
             failures[destination] = failure
             set_tag("error", True)
-            set_tag("reason", failure)
+            set_tag("reason", str(failure))
 
         return
 
@@ -405,7 +405,7 @@ class E2eKeysHandler:
         Returns:
             A map from user_id -> device_id -> device details
         """
-        set_tag("local_query", query)
+        set_tag("local_query", str(query))
         local_query: List[Tuple[str, Optional[str]]] = []
 
         result_dict: Dict[str, Dict[str, dict]] = {}
@@ -477,8 +477,8 @@ class E2eKeysHandler:
                 domain = get_domain_from_id(user_id)
                 remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
-        set_tag("local_key_query", local_query)
-        set_tag("remote_key_query", remote_queries)
+        set_tag("local_key_query", str(local_query))
+        set_tag("remote_key_query", str(remote_queries))
 
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
@@ -508,7 +508,7 @@ class E2eKeysHandler:
                 failure = _exception_to_failure(e)
                 failures[destination] = failure
                 set_tag("error", True)
-                set_tag("reason", failure)
+                set_tag("reason", str(failure))
 
         await make_deferred_yieldable(
             defer.gatherResults(
@@ -611,7 +611,7 @@ class E2eKeysHandler:
 
         result = await self.store.count_e2e_one_time_keys(user_id, device_id)
 
-        set_tag("one_time_key_counts", result)
+        set_tag("one_time_key_counts", str(result))
         return {"one_time_key_counts": result}
 
     async def _upload_one_time_keys_for_user(
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 446f509bdc..28dc08c22a 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, Optional, cast
 
 from typing_extensions import Literal
 
@@ -97,7 +97,7 @@ class E2eRoomKeysHandler:
                 user_id, version, room_id, session_id
             )
 
-            log_kv(results)
+            log_kv(cast(JsonDict, results))
             return results
 
     @trace
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 17e729f0c7..ad5cbf46a4 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -182,6 +182,8 @@ from typing import (
     Type,
     TypeVar,
     Union,
+    cast,
+    overload,
 )
 
 import attr
@@ -328,6 +330,7 @@ class _Sentinel(enum.Enum):
 
 P = ParamSpec("P")
 R = TypeVar("R")
+T = TypeVar("T")
 
 
 def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
@@ -343,22 +346,43 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
     return _only_if_tracing_inner
 
 
-def ensure_active_span(message: str, ret=None):
+@overload
+def ensure_active_span(
+    message: str,
+) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
+    ...
+
+
+@overload
+def ensure_active_span(
+    message: str, ret: T
+) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
+    ...
+
+
+def ensure_active_span(
+    message: str, ret: Optional[T] = None
+) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
     """Executes the operation only if opentracing is enabled and there is an active span.
     If there is no active span it logs message at the error level.
 
     Args:
         message: Message which fills in "There was no active span when trying to %s"
             in the error log if there is no active span and opentracing is enabled.
-        ret (object): return value if opentracing is None or there is no active span.
+        ret: return value if opentracing is None or there is no active span.
 
-    Returns (object): The result of the func or ret if opentracing is disabled or there
+    Returns:
+        The result of the func, falling back to ret if opentracing is disabled or there
         was no active span.
     """
 
-    def ensure_active_span_inner_1(func):
+    def ensure_active_span_inner_1(
+        func: Callable[P, R]
+    ) -> Callable[P, Union[Optional[T], R]]:
         @wraps(func)
-        def ensure_active_span_inner_2(*args, **kwargs):
+        def ensure_active_span_inner_2(
+            *args: P.args, **kwargs: P.kwargs
+        ) -> Union[Optional[T], R]:
             if not opentracing:
                 return ret
 
@@ -464,7 +488,7 @@ def start_active_span(
     finish_on_close: bool = True,
     *,
     tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
     """Starts an active opentracing span.
 
     Records the start time for the span, and sets it as the "active span" in the
@@ -502,7 +526,7 @@ def start_active_span_follows_from(
     *,
     inherit_force_tracing: bool = False,
     tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
     """Starts an active opentracing span, with additional references to previous spans
 
     Args:
@@ -717,7 +741,9 @@ def inject_response_headers(response_headers: Headers) -> None:
         response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
 
 
-@ensure_active_span("get the active span context as a dict", ret={})
+@ensure_active_span(
+    "get the active span context as a dict", ret=cast(Dict[str, str], {})
+)
 def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
     """
     Gets a span context as a dict. This can be used instead of manually
@@ -886,7 +912,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
         for i, arg in enumerate(argspec.args[1:]):
             set_tag("ARG_" + arg, args[i])  # type: ignore[index]
         set_tag("args", args[len(argspec.args) :])  # type: ignore[index]
-        set_tag("kwargs", kwargs)
+        set_tag("kwargs", str(kwargs))
         return func(*args, **kwargs)
 
     return _tag_args_inner
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index eef3462e10..7a1516d3a8 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -235,7 +235,7 @@ def run_as_background_process(
                         f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
                     )
                 else:
-                    ctx = nullcontext()
+                    ctx = nullcontext()  # type: ignore[assignment]
                 with ctx:
                     return await func(*args, **kwargs)
             except Exception:
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index eb1b85721f..e3f454896a 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -208,7 +208,9 @@ class KeyChangesServlet(RestServlet):
 
         # We want to enforce they do pass us one, but we ignore it and return
         # changes after the "to" as well as before.
-        set_tag("to", parse_string(request, "to"))
+        #
+        # XXX This does not enforce that "to" is passed.
+        set_tag("to", str(parse_string(request, "to")))
 
         from_token = await StreamToken.from_string(self.store, from_token_string)
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 422e0e65ca..73c95ffb6f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             (user_id, device_id), None
         )
 
-        set_tag("last_deleted_stream_id", last_deleted_stream_id)
+        set_tag("last_deleted_stream_id", str(last_deleted_stream_id))
 
         if last_deleted_stream_id:
             has_changed = self._device_inbox_stream_cache.has_entity_changed(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 7a6ed332aa..ca0fe8c4be 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -706,8 +706,8 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
             else:
                 results[user_id] = await self.get_cached_devices_for_user(user_id)
 
-        set_tag("in_cache", results)
-        set_tag("not_in_cache", user_ids_not_in_cache)
+        set_tag("in_cache", str(results))
+        set_tag("not_in_cache", str(user_ids_not_in_cache))
 
         return user_ids_not_in_cache, results
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 60f622ad71..46c0d06157 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -146,7 +146,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             key data.  The key data will be a dict in the same format as the
             DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
         """
-        set_tag("query_list", query_list)
+        set_tag("query_list", str(query_list))
         if not query_list:
             return {}
 
@@ -418,7 +418,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
-            set_tag("new_keys", new_keys)
+            set_tag("new_keys", str(new_keys))
             # We are protected from race between lookup and insertion due to
             # a unique constraint. If there is a race of two calls to
             # `add_e2e_one_time_keys` then they'll conflict and we will only
@@ -1161,7 +1161,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
             set_tag("time_now", time_now)
-            set_tag("device_keys", device_keys)
+            set_tag("device_keys", str(device_keys))
 
             old_key_json = self.db_pool.simple_select_one_onecol_txn(
                 txn,