# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict
from synapse.api.constants import ToDeviceEventTypes
from synapse.api.errors import SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
SynapseTags,
get_active_span_text_map,
log_kv,
set_tag,
)
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
# Only imported for type annotations; guarded so that it is not imported at
# runtime.
if TYPE_CHECKING:
    from synapse.server import HomeServer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class DeviceMessageHandler:
    """Handles to-device messages: those sent by local users (see
    `send_device_message`) and those received from remote servers as
    `m.direct_to_device` EDUs (see `on_direct_to_device_edu`).
    """

    def __init__(self, hs: "HomeServer"):
        """
        Args:
            hs: server
        """
        self.store = hs.get_datastore()
        self.notifier = hs.get_notifier()
        self.is_mine = hs.is_mine

        # We only need to poke the federation sender explicitly if it's on the
        # same instance. Other federation sender instances will get notified by
        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
        # in the to-device replication stream.
        self.federation_sender = None
        if hs.should_send_federation():
            self.federation_sender = hs.get_federation_sender()

        # If we can handle the to device EDUs we do so, otherwise we route them
        # to the appropriate worker.
        if hs.get_instance_name() in hs.config.worker.writers.to_device:
            hs.get_federation_registry().register_edu_handler(
                "m.direct_to_device", self.on_direct_to_device_edu
            )
        else:
            hs.get_federation_registry().register_instances_for_edu(
                "m.direct_to_device",
                hs.config.worker.writers.to_device,
            )

        # The handler to call when we think a user's device list might be out of
        # sync. We do all device list resyncing on the master instance, so if
        # we're on a worker we hit the device resync replication API.
        if hs.config.worker.worker_app is None:
            self._user_device_resync = (
                hs.get_device_handler().device_list_updater.user_device_resync
            )
        else:
            self._user_device_resync = (
                ReplicationUserDevicesResyncRestServlet.make_client(hs)
            )

        # a rate limiter for room key requests.  The keys are
        # (sending_user_id, sending_device_id).
        self._ratelimiter = Ratelimiter(
            store=self.store,
            clock=hs.get_clock(),
            rate_hz=hs.config.rc_key_requests.per_second,
            burst_count=hs.config.rc_key_requests.burst_count,
        )

    async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
        """Handle an `m.direct_to_device` EDU.

        Messages addressed to local users are written to the device inbox and
        the notifier is poked for the affected users. EDUs whose sender does
        not belong to `origin` are dropped without being stored.

        Args:
            origin: the server name the EDU is attributed to; it must match
                the domain of the EDU's sender.
            content: the EDU content. The keys "sender", "type", "message_id"
                and "messages" are read; "messages" maps
                user_id -> device_id -> message content.

        Raises:
            SynapseError: (400) if any recipient is not a user on this server.
        """
        local_messages = {}
        sender_user_id = content["sender"]
        if origin != get_domain_from_id(sender_user_id):
            logger.warning(
                "Dropping device message from %r with spoofed sender %r",
                origin,
                sender_user_id,
            )
            # Actually drop it: a server may only send messages on behalf of
            # its own users, so nothing below may run for a spoofed sender.
            return

        message_type = content["type"]
        message_id = content["message_id"]
        for user_id, by_device in content["messages"].items():
            # we use UserID.from_string to catch invalid user ids
            if not self.is_mine(UserID.from_string(user_id)):
                logger.warning("To-device message to non-local user %s", user_id)
                raise SynapseError(400, "Not a user here")

            if not by_device:
                continue

            # Ratelimit key requests by the sending user. No sending device id
            # is available at this point, so the key is (sender_user_id, None).
            if message_type == ToDeviceEventTypes.RoomKeyRequest:
                allowed, _ = await self._ratelimiter.can_do_action(
                    None, (sender_user_id, None)
                )
                if not allowed:
                    logger.info(
                        "Dropping room_key_request from %s to %s due to rate limit",
                        sender_user_id,
                        user_id,
                    )
                    continue

            messages_by_device = {
                device_id: {
                    "content": message_content,
                    "type": message_type,
                    "sender": sender_user_id,
                }
                for device_id, message_content in by_device.items()
            }
            local_messages[user_id] = messages_by_device

            await self._check_for_unknown_devices(
                message_type, sender_user_id, by_device
            )

        stream_id = await self.store.add_messages_from_remote_to_device_inbox(
            origin, message_id, local_messages
        )

        self.notifier.on_new_event(
            "to_device_key", stream_id, users=local_messages.keys()
        )

    async def _check_for_unknown_devices(
        self,
        message_type: str,
        sender_user_id: str,
        by_device: Dict[str, Dict[str, Any]],
    ) -> None:
        """Checks inbound device messages for unknown remote devices, and if
        found marks the remote cache for the user as stale.

        Only `m.room_key_request` messages are inspected; every other message
        type returns immediately.

        Args:
            message_type: the type of the inbound to-device message.
            sender_user_id: the remote user who sent the message.
            by_device: map from recipient device ID to message content; the
                "requesting_device_id" field of each content is checked.
        """
        if message_type != "m.room_key_request":
            return

        # Get the sending device IDs
        requesting_device_ids = {
            message_content.get("requesting_device_id")
            for message_content in by_device.values()
        }
        # A request without a `requesting_device_id` yields None, which names
        # no device. Left in, it would never match the cache and so would
        # force a cache invalidation and resync on every such request.
        requesting_device_ids.discard(None)

        # Check if we are tracking the devices of the remote user.
        room_ids = await self.store.get_rooms_for_user(sender_user_id)
        if not room_ids:
            logger.info(
                "Received device message from remote device we don't"
                " share a room with: %s %s",
                sender_user_id,
                requesting_device_ids,
            )
            return

        # If we are tracking check that we know about the sending
        # devices.
        cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)

        unknown_devices = requesting_device_ids - set(cached_devices)
        if unknown_devices:
            logger.info(
                "Received device message from remote device not in our cache: %s %s",
                sender_user_id,
                unknown_devices,
            )
            await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)

            # Immediately attempt a resync in the background
            run_in_background(self._user_device_resync, user_id=sender_user_id)

    async def send_device_message(
        self,
        requester: Requester,
        message_type: str,
        messages: Dict[str, Dict[str, JsonDict]],
    ) -> None:
        """Send to-device messages on behalf of `requester`.

        Messages for local users are written to the device inbox; messages for
        remote users are grouped per destination server into EDU contents and
        stored alongside them. If this instance has a federation sender, it is
        poked for each remote destination.

        Args:
            requester: the user (and device) sending the messages.
            message_type: the type of the to-device message.
            messages: map of user_id -> device_id -> message content.
        """
        sender_user_id = requester.user.to_string()

        message_id = random_string(16)
        set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)

        log_kv({"number_of_to_device_messages": len(messages)})
        set_tag("sender", sender_user_id)
        local_messages = {}
        remote_messages = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
        for user_id, by_device in messages.items():
            # Ratelimit local cross-user key requests by the sending device.
            if (
                message_type == ToDeviceEventTypes.RoomKeyRequest
                and user_id != sender_user_id
            ):
                allowed, _ = await self._ratelimiter.can_do_action(
                    requester, (sender_user_id, requester.device_id)
                )
                if not allowed:
                    logger.info(
                        "Dropping room_key_request from %s to %s due to rate limit",
                        sender_user_id,
                        user_id,
                    )
                    continue

            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                messages_by_device = {
                    device_id: {
                        "content": message_content,
                        "type": message_type,
                        "sender": sender_user_id,
                        "message_id": message_id,
                    }
                    for device_id, message_content in by_device.items()
                }
                if messages_by_device:
                    local_messages[user_id] = messages_by_device
                    log_kv(
                        {
                            "user_id": user_id,
                            "device_id": list(messages_by_device),
                        }
                    )
            else:
                destination = get_domain_from_id(user_id)
                remote_messages.setdefault(destination, {})[user_id] = by_device

        # The active tracing span context is serialised into each outgoing EDU.
        context = get_active_span_text_map()

        remote_edu_contents = {}
        for destination, destination_messages in remote_messages.items():
            log_kv({"destination": destination})
            remote_edu_contents[destination] = {
                "messages": destination_messages,
                "sender": sender_user_id,
                "type": message_type,
                "message_id": message_id,
                "org.matrix.opentracing_context": json_encoder.encode(context),
            }

        stream_id = await self.store.add_messages_to_device_inbox(
            local_messages, remote_edu_contents
        )

        self.notifier.on_new_event(
            "to_device_key", stream_id, users=local_messages.keys()
        )

        if self.federation_sender:
            for destination in remote_messages:
                # Enqueue a new federation transaction to send the new
                # device messages to each remote destination.
                self.federation_sender.send_device_messages(destination)