diff --git a/changelog.d/13328.misc b/changelog.d/13328.misc
new file mode 100644
index 0000000000..d15fb5fc37
--- /dev/null
+++ b/changelog.d/13328.misc
@@ -0,0 +1 @@
+Add type hints to `trace` decorator.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 66e6305562..7c450ecad0 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -217,7 +217,7 @@ class FederationClient(FederationBase):
)
async def claim_client_keys(
- self, destination: str, content: JsonDict, timeout: int
+ self, destination: str, content: JsonDict, timeout: Optional[int]
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9e84bd677e..32074b8ca6 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -619,7 +619,7 @@ class TransportLayerClient:
)
async def claim_client_keys(
- self, destination: str, query_content: JsonDict, timeout: int
+ self, destination: str, query_content: JsonDict, timeout: Optional[int]
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 52bb5c9c55..84c28c480e 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -92,7 +92,11 @@ class E2eKeysHandler:
@trace
async def query_devices(
- self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
+ self,
+ query_body: JsonDict,
+ timeout: int,
+ from_user_id: str,
+ from_device_id: Optional[str],
) -> JsonDict:
"""Handle a device key query from a client
@@ -120,9 +124,7 @@ class E2eKeysHandler:
the number of in-flight queries at a time.
"""
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
- device_keys_query: Dict[str, Iterable[str]] = query_body.get(
- "device_keys", {}
- )
+ device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -392,7 +394,7 @@ class E2eKeysHandler:
@trace
async def query_local_devices(
- self, query: Dict[str, Optional[List[str]]]
+ self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
@@ -461,7 +463,7 @@ class E2eKeysHandler:
@trace
async def claim_one_time_keys(
- self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+ self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 50c57940f9..17e729f0c7 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -84,14 +84,13 @@ the function becomes the operation name for the span.
return something_usual_and_useful
-Operation names can be explicitly set for a function by passing the
-operation name to ``trace``
+Operation names can be explicitly set for a function by using ``trace_with_opname``:
.. code-block:: python
- from synapse.logging.opentracing import trace
+ from synapse.logging.opentracing import trace_with_opname
- @trace(opname="a_better_operation_name")
+ @trace_with_opname("a_better_operation_name")
def interesting_badly_named_function(*args, **kwargs):
# Does all kinds of cool and expected things
return something_usual_and_useful
@@ -798,33 +797,31 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
# Tracing decorators
-def trace(func=None, opname: Optional[str] = None):
+def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
- Decorator to trace a function.
- Sets the operation name to that of the function's or that given
- as operation_name. See the module's doc string for usage
- examples.
+ Decorator to trace a function with a custom opname.
+
+ See the module's doc string for usage examples.
+
"""
- def decorator(func):
+ def decorator(func: Callable[P, R]) -> Callable[P, R]:
if opentracing is None:
return func # type: ignore[unreachable]
- _opname = opname if opname else func.__name__
-
if inspect.iscoroutinefunction(func):
@wraps(func)
- async def _trace_inner(*args, **kwargs):
- with start_active_span(_opname):
- return await func(*args, **kwargs)
+ async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+ with start_active_span(opname):
+ return await func(*args, **kwargs) # type: ignore[misc]
else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
@wraps(func)
- def _trace_inner(*args, **kwargs):
- scope = start_active_span(_opname)
+ def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+ scope = start_active_span(opname)
scope.__enter__()
try:
@@ -858,12 +855,21 @@ def trace(func=None, opname: Optional[str] = None):
scope.__exit__(type(e), None, e.__traceback__)
raise
- return _trace_inner
+ return _trace_inner # type: ignore[return-value]
- if func:
- return decorator(func)
- else:
- return decorator
+ return decorator
+
+
+def trace(func: Callable[P, R]) -> Callable[P, R]:
+ """
+ Decorator to trace a function.
+
+ Sets the operation name to that of the function's name.
+
+ See the module's doc string for usage examples.
+ """
+
+ return trace_with_opname(func.__name__)(func)
def tag_args(func: Callable[P, R]) -> Callable[P, R]:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index a4ae4040c3..561ad5bf04 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -29,7 +29,7 @@ from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer, is_method_cancellable
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -196,7 +196,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
"ascii"
)
- @trace(opname="outgoing_replication_request")
+ @trace_with_opname("outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index ce806e3c11..eb1b85721f 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler
@@ -71,7 +71,7 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
- @trace(opname="upload_keys")
+ @trace_with_opname("upload_keys")
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index 37e39570f6..f7081f638e 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple, cast
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
@@ -127,7 +127,7 @@ class RoomKeysServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
- version = parse_string(request, "version")
+ version = parse_string(request, "version", required=True)
if session_id:
body = {"sessions": {session_id: body}}
@@ -196,8 +196,11 @@ class RoomKeysServlet(RestServlet):
user_id = requester.user.to_string()
version = parse_string(request, "version", required=True)
- room_keys = await self.e2e_room_keys_handler.get_room_keys(
- user_id, version, room_id, session_id
+ room_keys = cast(
+ JsonDict,
+ await self.e2e_room_keys_handler.get_room_keys(
+ user_id, version, room_id, session_id
+ ),
)
# Convert room_keys to the right format to return.
@@ -240,7 +243,7 @@ class RoomKeysServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
- version = parse_string(request, "version")
+ version = parse_string(request, "version", required=True)
ret = await self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index 3322c8ef48..1a8e9a96d4 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -19,7 +19,7 @@ from synapse.http import servlet
from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag, trace
+from synapse.logging.opentracing import set_tag, trace_with_opname
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict
@@ -43,7 +43,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
- @trace(opname="sendToDevice")
+ @trace_with_opname("sendToDevice")
def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 8bbf35148d..c2989765ce 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -37,7 +37,7 @@ from synapse.handlers.sync import (
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
@@ -210,7 +210,7 @@ class SyncRestServlet(RestServlet):
logger.debug("Event formatting complete")
return 200, response_content
- @trace(opname="sync.encode_response")
+ @trace_with_opname("sync.encode_response")
async def encode_response(
self,
time_now: int,
@@ -315,7 +315,7 @@ class SyncRestServlet(RestServlet):
]
}
- @trace(opname="sync.encode_joined")
+ @trace_with_opname("sync.encode_joined")
async def encode_joined(
self,
rooms: List[JoinedSyncResult],
@@ -340,7 +340,7 @@ class SyncRestServlet(RestServlet):
return joined
- @trace(opname="sync.encode_invited")
+ @trace_with_opname("sync.encode_invited")
async def encode_invited(
self,
rooms: List[InvitedSyncResult],
@@ -371,7 +371,7 @@ class SyncRestServlet(RestServlet):
return invited
- @trace(opname="sync.encode_knocked")
+ @trace_with_opname("sync.encode_knocked")
async def encode_knocked(
self,
rooms: List[KnockedSyncResult],
@@ -420,7 +420,7 @@ class SyncRestServlet(RestServlet):
return knocked
- @trace(opname="sync.encode_archived")
+ @trace_with_opname("sync.encode_archived")
async def encode_archived(
self,
rooms: List[ArchivedSyncResult],
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index adde5d0978..7a6ed332aa 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
@trace
async def get_user_devices_from_cache(
- self, query_list: List[Tuple[str, str]]
+ self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9b293475c8..60f622ad71 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -22,11 +22,14 @@ from typing import (
List,
Optional,
Tuple,
+ Union,
cast,
+ overload,
)
import attr
from canonicaljson import encode_canonical_json
+from typing_extensions import Literal
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import (
@@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
- result = {"device_id": device_id}
+ result: JsonDict = {"device_id": device_id}
keys = device.keys
if keys:
@@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
rv[user_id] = {}
for device_id, device_info in device_keys.items():
r = device_info.keys
+ if r is None:
+ continue
+
r["unsigned"] = {}
display_name = device_info.display_name
if display_name is not None:
@@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return rv
+ @overload
+ async def get_e2e_device_keys_and_signatures(
+ self,
+ query_list: Collection[Tuple[str, Optional[str]]],
+ include_all_devices: Literal[False] = False,
+ ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+ ...
+
+ @overload
+ async def get_e2e_device_keys_and_signatures(
+ self,
+ query_list: Collection[Tuple[str, Optional[str]]],
+ include_all_devices: bool = False,
+ include_deleted_devices: Literal[False] = False,
+ ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+ ...
+
+ @overload
+ async def get_e2e_device_keys_and_signatures(
+ self,
+ query_list: Collection[Tuple[str, Optional[str]]],
+ include_all_devices: Literal[True],
+ include_deleted_devices: Literal[True],
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+ ...
+
@trace
async def get_e2e_device_keys_and_signatures(
self,
- query_list: List[Tuple[str, Optional[str]]],
+ query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
- ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+ ) -> Union[
+ Dict[str, Dict[str, DeviceKeyLookupResult]],
+ Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
+ ]:
"""Fetch a list of device keys
Any cross-signatures made on the keys by the owner of the device are also
@@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False
- row = await self.db_pool.runInteraction(
+ claim_row = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
_claim_e2e_one_time_key,
user_id,
@@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
algorithm,
db_autocommit=db_autocommit,
)
- if row:
+ if claim_row:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
- device_results[row[0]] = row[1]
+ device_results[claim_row[0]] = claim_row[1]
continue
# No one-time key available, so see if there's a fallback
|