diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 64ef7f63ab..9cac5a8463 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Any, Dict
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
@@ -24,18 +24,22 @@ from synapse.logging.opentracing import (
set_tag,
start_active_span,
)
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+
logger = logging.getLogger(__name__)
class DeviceMessageHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
"""
Args:
- hs (synapse.server.HomeServer): server
+ hs: server
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@@ -48,7 +52,7 @@ class DeviceMessageHandler:
self._device_list_updater = hs.get_device_handler().device_list_updater
- async def on_direct_to_device_edu(self, origin, content):
+ async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
@@ -95,7 +99,7 @@ class DeviceMessageHandler:
message_type: str,
sender_user_id: str,
by_device: Dict[str, Dict[str, Any]],
- ):
+ ) -> None:
"""Checks inbound device messages for unknown remote devices, and if
found marks the remote cache for the user as stale.
"""
@@ -138,11 +142,16 @@ class DeviceMessageHandler:
self._device_list_updater.user_device_resync, sender_user_id
)
- async def send_device_message(self, sender_user_id, message_type, messages):
+ async def send_device_message(
+ self,
+ sender_user_id: str,
+ message_type: str,
+ messages: Dict[str, Dict[str, JsonDict]],
+ ) -> None:
set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id)
local_messages = {}
- remote_messages = {}
+ remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
for user_id, by_device in messages.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
|