summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13328.misc1
-rw-r--r--synapse/federation/federation_client.py2
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/handlers/e2e_keys.py16
-rw-r--r--synapse/logging/opentracing.py50
-rw-r--r--synapse/replication/http/_base.py4
-rw-r--r--synapse/rest/client/keys.py4
-rw-r--r--synapse/rest/client/room_keys.py13
-rw-r--r--synapse/rest/client/sendtodevice.py4
-rw-r--r--synapse/rest/client/sync.py12
-rw-r--r--synapse/storage/databases/main/devices.py2
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py47
12 files changed, 102 insertions, 55 deletions
diff --git a/changelog.d/13328.misc b/changelog.d/13328.misc
new file mode 100644
index 0000000000..d15fb5fc37
--- /dev/null
+++ b/changelog.d/13328.misc
@@ -0,0 +1 @@
+Add type hints to `trace` decorator.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 66e6305562..7c450ecad0 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -217,7 +217,7 @@ class FederationClient(FederationBase):
         )
 
     async def claim_client_keys(
-        self, destination: str, content: JsonDict, timeout: int
+        self, destination: str, content: JsonDict, timeout: Optional[int]
     ) -> JsonDict:
         """Claims one-time keys for a device hosted on a remote server.
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9e84bd677e..32074b8ca6 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -619,7 +619,7 @@ class TransportLayerClient:
         )
 
     async def claim_client_keys(
-        self, destination: str, query_content: JsonDict, timeout: int
+        self, destination: str, query_content: JsonDict, timeout: Optional[int]
     ) -> JsonDict:
         """Claim one-time keys for a list of devices hosted on a remote server.
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 52bb5c9c55..84c28c480e 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -92,7 +92,11 @@ class E2eKeysHandler:
 
     @trace
     async def query_devices(
-        self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
+        self,
+        query_body: JsonDict,
+        timeout: int,
+        from_user_id: str,
+        from_device_id: Optional[str],
     ) -> JsonDict:
         """Handle a device key query from a client
 
@@ -120,9 +124,7 @@ class E2eKeysHandler:
                 the number of in-flight queries at a time.
         """
         async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
-            device_keys_query: Dict[str, Iterable[str]] = query_body.get(
-                "device_keys", {}
-            )
+            device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
 
             # separate users by domain.
             # make a map from domain to user_id to device_ids
@@ -392,7 +394,7 @@ class E2eKeysHandler:
 
     @trace
     async def query_local_devices(
-        self, query: Dict[str, Optional[List[str]]]
+        self, query: Mapping[str, Optional[List[str]]]
     ) -> Dict[str, Dict[str, dict]]:
         """Get E2E device keys for local users
 
@@ -461,7 +463,7 @@ class E2eKeysHandler:
 
     @trace
     async def claim_one_time_keys(
-        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
     ) -> JsonDict:
         local_query: List[Tuple[str, str, str]] = []
         remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 50c57940f9..17e729f0c7 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -84,14 +84,13 @@ the function becomes the operation name for the span.
        return something_usual_and_useful
 
 
-Operation names can be explicitly set for a function by passing the
-operation name to ``trace``
+Operation names can be explicitly set for a function by using ``trace_with_opname``:
 
 .. code-block:: python
 
-   from synapse.logging.opentracing import trace
+   from synapse.logging.opentracing import trace_with_opname
 
-   @trace(opname="a_better_operation_name")
+   @trace_with_opname("a_better_operation_name")
    def interesting_badly_named_function(*args, **kwargs):
        # Does all kinds of cool and expected things
        return something_usual_and_useful
@@ -798,33 +797,31 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
 # Tracing decorators
 
 
-def trace(func=None, opname: Optional[str] = None):
+def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
     """
-    Decorator to trace a function.
-    Sets the operation name to that of the function's or that given
-    as operation_name. See the module's doc string for usage
-    examples.
+    Decorator to trace a function with a custom opname.
+
+    See the module's doc string for usage examples.
+
     """
 
-    def decorator(func):
+    def decorator(func: Callable[P, R]) -> Callable[P, R]:
         if opentracing is None:
             return func  # type: ignore[unreachable]
 
-        _opname = opname if opname else func.__name__
-
         if inspect.iscoroutinefunction(func):
 
             @wraps(func)
-            async def _trace_inner(*args, **kwargs):
-                with start_active_span(_opname):
-                    return await func(*args, **kwargs)
+            async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+                with start_active_span(opname):
+                    return await func(*args, **kwargs)  # type: ignore[misc]
 
         else:
             # The other case here handles both sync functions and those
             # decorated with inlineDeferred.
             @wraps(func)
-            def _trace_inner(*args, **kwargs):
-                scope = start_active_span(_opname)
+            def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+                scope = start_active_span(opname)
                 scope.__enter__()
 
                 try:
@@ -858,12 +855,21 @@ def trace(func=None, opname: Optional[str] = None):
                     scope.__exit__(type(e), None, e.__traceback__)
                     raise
 
-        return _trace_inner
+        return _trace_inner  # type: ignore[return-value]
 
-    if func:
-        return decorator(func)
-    else:
-        return decorator
+    return decorator
+
+
+def trace(func: Callable[P, R]) -> Callable[P, R]:
+    """
+    Decorator to trace a function.
+
+    Sets the operation name to that of the function's name.
+
+    See the module's doc string for usage examples.
+    """
+
+    return trace_with_opname(func.__name__)(func)
 
 
 def tag_args(func: Callable[P, R]) -> Callable[P, R]:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index a4ae4040c3..561ad5bf04 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -29,7 +29,7 @@ from synapse.http import RequestTimedOutError
 from synapse.http.server import HttpServer, is_method_cancellable
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import trace_with_opname
 from synapse.types import JsonDict
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import random_string
@@ -196,7 +196,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                 "ascii"
             )
 
-        @trace(opname="outgoing_replication_request")
+        @trace_with_opname("outgoing_replication_request")
         async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
             with outgoing_gauge.track_inprogress():
                 if instance_name == local_instance_name:
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index ce806e3c11..eb1b85721f 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
     parse_string,
 )
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
 from synapse.types import JsonDict, StreamToken
 
 from ._base import client_patterns, interactive_auth_handler
@@ -71,7 +71,7 @@ class KeyUploadServlet(RestServlet):
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
         self.device_handler = hs.get_device_handler()
 
-    @trace(opname="upload_keys")
+    @trace_with_opname("upload_keys")
     async def on_POST(
         self, request: SynapseRequest, device_id: Optional[str]
     ) -> Tuple[int, JsonDict]:
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index 37e39570f6..f7081f638e 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple, cast
 
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
@@ -127,7 +127,7 @@ class RoomKeysServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
-        version = parse_string(request, "version")
+        version = parse_string(request, "version", required=True)
 
         if session_id:
             body = {"sessions": {session_id: body}}
@@ -196,8 +196,11 @@ class RoomKeysServlet(RestServlet):
         user_id = requester.user.to_string()
         version = parse_string(request, "version", required=True)
 
-        room_keys = await self.e2e_room_keys_handler.get_room_keys(
-            user_id, version, room_id, session_id
+        room_keys = cast(
+            JsonDict,
+            await self.e2e_room_keys_handler.get_room_keys(
+                user_id, version, room_id, session_id
+            ),
         )
 
         # Convert room_keys to the right format to return.
@@ -240,7 +243,7 @@ class RoomKeysServlet(RestServlet):
 
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
-        version = parse_string(request, "version")
+        version = parse_string(request, "version", required=True)
 
         ret = await self.e2e_room_keys_handler.delete_room_keys(
             user_id, version, room_id, session_id
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index 3322c8ef48..1a8e9a96d4 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -19,7 +19,7 @@ from synapse.http import servlet
 from synapse.http.server import HttpServer
 from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag, trace
+from synapse.logging.opentracing import set_tag, trace_with_opname
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.types import JsonDict
 
@@ -43,7 +43,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         self.txns = HttpTransactionCache(hs)
         self.device_message_handler = hs.get_device_message_handler()
 
-    @trace(opname="sendToDevice")
+    @trace_with_opname("sendToDevice")
     def on_PUT(
         self, request: SynapseRequest, message_type: str, txn_id: str
     ) -> Awaitable[Tuple[int, JsonDict]]:
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 8bbf35148d..c2989765ce 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -37,7 +37,7 @@ from synapse.handlers.sync import (
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import trace_with_opname
 from synapse.types import JsonDict, StreamToken
 from synapse.util import json_decoder
 
@@ -210,7 +210,7 @@ class SyncRestServlet(RestServlet):
         logger.debug("Event formatting complete")
         return 200, response_content
 
-    @trace(opname="sync.encode_response")
+    @trace_with_opname("sync.encode_response")
     async def encode_response(
         self,
         time_now: int,
@@ -315,7 +315,7 @@ class SyncRestServlet(RestServlet):
             ]
         }
 
-    @trace(opname="sync.encode_joined")
+    @trace_with_opname("sync.encode_joined")
     async def encode_joined(
         self,
         rooms: List[JoinedSyncResult],
@@ -340,7 +340,7 @@ class SyncRestServlet(RestServlet):
 
         return joined
 
-    @trace(opname="sync.encode_invited")
+    @trace_with_opname("sync.encode_invited")
     async def encode_invited(
         self,
         rooms: List[InvitedSyncResult],
@@ -371,7 +371,7 @@ class SyncRestServlet(RestServlet):
 
         return invited
 
-    @trace(opname="sync.encode_knocked")
+    @trace_with_opname("sync.encode_knocked")
     async def encode_knocked(
         self,
         rooms: List[KnockedSyncResult],
@@ -420,7 +420,7 @@ class SyncRestServlet(RestServlet):
 
         return knocked
 
-    @trace(opname="sync.encode_archived")
+    @trace_with_opname("sync.encode_archived")
     async def encode_archived(
         self,
         rooms: List[ArchivedSyncResult],
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index adde5d0978..7a6ed332aa 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
 
     @trace
     async def get_user_devices_from_cache(
-        self, query_list: List[Tuple[str, str]]
+        self, query_list: List[Tuple[str, Optional[str]]]
     ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9b293475c8..60f622ad71 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -22,11 +22,14 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Union,
     cast,
+    overload,
 )
 
 import attr
 from canonicaljson import encode_canonical_json
+from typing_extensions import Literal
 
 from synapse.api.constants import DeviceKeyAlgorithms
 from synapse.appservice import (
@@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             user_devices = devices[user_id]
             results = []
             for device_id, device in user_devices.items():
-                result = {"device_id": device_id}
+                result: JsonDict = {"device_id": device_id}
 
                 keys = device.keys
                 if keys:
@@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
                 r = device_info.keys
+                if r is None:
+                    continue
+
                 r["unsigned"] = {}
                 display_name = device_info.display_name
                 if display_name is not None:
@@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         return rv
 
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: Literal[False] = False,
+    ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+        ...
+
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: bool = False,
+        include_deleted_devices: Literal[False] = False,
+    ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+        ...
+
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: Literal[True],
+        include_deleted_devices: Literal[True],
+    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+        ...
+
     @trace
     async def get_e2e_device_keys_and_signatures(
         self,
-        query_list: List[Tuple[str, Optional[str]]],
+        query_list: Collection[Tuple[str, Optional[str]]],
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
-    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+    ) -> Union[
+        Dict[str, Dict[str, DeviceKeyLookupResult]],
+        Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
+    ]:
         """Fetch a list of device keys
 
         Any cross-signatures made on the keys by the owner of the device are also
@@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
                 db_autocommit = False
 
-            row = await self.db_pool.runInteraction(
+            claim_row = await self.db_pool.runInteraction(
                 "claim_e2e_one_time_keys",
                 _claim_e2e_one_time_key,
                 user_id,
@@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 algorithm,
                 db_autocommit=db_autocommit,
             )
-            if row:
+            if claim_row:
                 device_results = results.setdefault(user_id, {}).setdefault(
                     device_id, {}
                 )
-                device_results[row[0]] = row[1]
+                device_results[claim_row[0]] = claim_row[1]
                 continue
 
             # No one-time key available, so see if there's a fallback