diff options
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/devicemessage.py | 117 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 19 | ||||
-rw-r--r-- | synapse/handlers/message.py | 44 | ||||
-rw-r--r-- | synapse/handlers/presence.py | 33 | ||||
-rw-r--r-- | synapse/handlers/typing.py | 9 |
5 files changed, 167 insertions, 55 deletions
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py new file mode 100644 index 0000000000..c5368e5df2 --- /dev/null +++ b/synapse/handlers/devicemessage.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.types import get_domain_from_id +from synapse.util.stringutils import random_string + + +logger = logging.getLogger(__name__) + + +class DeviceMessageHandler(object): + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self.is_mine_id = hs.is_mine_id + self.federation = hs.get_replication_layer() + + self.federation.register_edu_handler( + "m.direct_to_device", self.on_direct_to_device_edu + ) + + @defer.inlineCallbacks + def on_direct_to_device_edu(self, origin, content): + local_messages = {} + sender_user_id = content["sender"] + if origin != get_domain_from_id(sender_user_id): + logger.warn( + "Dropping device message from %r with spoofed sender %r", + origin, sender_user_id + ) + message_type = content["type"] + message_id = content["message_id"] + for user_id, by_device in content["messages"].items(): + messages_by_device = { + device_id: { + "content": message_content, + "type": message_type, + "sender": sender_user_id, + } + for device_id, message_content in by_device.items() + } + if messages_by_device: + local_messages[user_id] = messages_by_device + + stream_id = yield self.store.add_messages_from_remote_to_device_inbox( + origin, message_id, local_messages + ) + + self.notifier.on_new_event( + "to_device_key", stream_id, users=local_messages.keys() + ) + + @defer.inlineCallbacks + def send_device_message(self, sender_user_id, message_type, messages): + + local_messages = {} + remote_messages = {} + for user_id, by_device in messages.items(): + if self.is_mine_id(user_id): + messages_by_device = { + device_id: { + "content": message_content, + "type": message_type, + "sender": sender_user_id, + } + for device_id, message_content in by_device.items() + } + if messages_by_device: + local_messages[user_id] = messages_by_device + else: + destination = get_domain_from_id(user_id) + remote_messages.setdefault(destination, {})[user_id] = by_device + + message_id = random_string(16) + + remote_edu_contents = {} + for destination, messages in remote_messages.items(): + remote_edu_contents[destination] = { + "messages": messages, + "sender": sender_user_id, + "type": message_type, + "message_id": message_id, + } + + stream_id = yield self.store.add_messages_to_device_inbox( + local_messages, remote_edu_contents + ) + + self.notifier.on_new_event( + "to_device_key", stream_id, users=local_messages.keys() + ) + + for destination in remote_messages.keys(): + # Enqueue a new federation transaction to send the new + # device messages to each remote destination. + self.federation.send_device_messages(destination) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index dc90a5dde4..8a1038c44a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -832,11 +832,13 @@ class FederationHandler(BaseHandler): new_pdu = event - message_handler = self.hs.get_handlers().message_handler - destinations = yield message_handler.get_joined_hosts_for_room_from_state( - context + users_in_room = yield self.store.get_joined_users_from_context(event, context) + + destinations = set( + get_domain_from_id(user_id) for user_id in users_in_room + if not self.hs.is_mine_id(user_id) ) - destinations = set(destinations) + destinations.discard(origin) logger.debug( @@ -1055,11 +1057,12 @@ class FederationHandler(BaseHandler): new_pdu = event - message_handler = self.hs.get_handlers().message_handler - destinations = yield message_handler.get_joined_hosts_for_room_from_state( - context + users_in_room = yield self.store.get_joined_users_from_context(event, context) + + destinations = set( + get_domain_from_id(user_id) for user_id in users_in_room + if not self.hs.is_mine_id(user_id) ) - destinations = set(destinations) destinations.discard(origin) logger.debug( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3577db0595..178209a209 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -30,7 +30,6 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.metrics import measure_func -from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -945,7 +944,12 @@ class MessageHandler(BaseHandler): event_stream_id, max_stream_id ) - destinations = yield self.get_joined_hosts_for_room_from_state(context) + users_in_room = yield self.store.get_joined_users_from_context(event, context) + + destinations = [ + get_domain_from_id(user_id) for user_id in users_in_room + if not self.hs.is_mine_id(user_id) + ] @defer.inlineCallbacks def _notify(): @@ -963,39 +967,3 @@ class MessageHandler(BaseHandler): preserve_fn(federation_handler.handle_new_event)( event, destinations=destinations, ) - - def get_joined_hosts_for_room_from_state(self, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - return self._get_joined_hosts_for_room_from_state( - state_group, context.current_state_ids - ) - - @cachedInlineCallbacks(num_args=1, cache_context=True) - def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids, - cache_context): - - # Don't bother getting state for people on the same HS - current_state = yield self.store.get_events([ - e_id for key, e_id in current_state_ids.items() - if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1]) - ]) - - destinations = set() - for e in current_state.itervalues(): - try: - if e.type == EventTypes.Member: - if e.content["membership"] == Membership.JOIN: - destinations.add(get_domain_from_id(e.state_key)) - except SynapseError: - logger.warn( - "Failed to get destination from event %s", e.event_id - ) - - defer.returnValue(destinations) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index cf82a2336e..7a3c16a8aa 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -52,6 +52,11 @@ bump_active_time_counter = metrics.register_counter("bump_active_time") get_updates_counter = metrics.register_counter("get_updates", labels=["type"]) +notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"]) +state_transition_counter = metrics.register_counter( + "state_transition", labels=["from", "to"] +) + # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them # "currently_active" @@ -646,6 +651,13 @@ class PresenceHandler(object): ) continue + if get_domain_from_id(user_id) != origin: + logger.info( + "Got presence update from %r with bad 'user_id': %r", + origin, user_id, + ) + continue + presence_state = push.get("presence", None) if not presence_state: logger.info( @@ -939,27 +951,32 @@ class PresenceHandler(object): def should_notify(old_state, new_state): """Decides if a presence state change should be sent to interested parties. """ + if old_state == new_state: + return False + if old_state.status_msg != new_state.status_msg: + notify_reason_counter.inc("status_msg_change") return True - if old_state.state == PresenceState.ONLINE: - if new_state.state != PresenceState.ONLINE: - # Always notify for online -> anything - return True + if old_state.state != new_state.state: + notify_reason_counter.inc("state_change") + state_transition_counter.inc(old_state.state, new_state.state) + return True + if old_state.state == PresenceState.ONLINE: if new_state.currently_active != old_state.currently_active: + notify_reason_counter.inc("current_active_change") return True if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Only notify about last active bumps if we're not currently acive - if not (old_state.currently_active and new_state.currently_active): + if not new_state.currently_active: + notify_reason_counter.inc("last_active_change_online") return True elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Always notify for a transition where last active gets bumped. - return True - - if old_state.state != new_state.state: + notify_reason_counter.inc("last_active_change_not_online") return True return False diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0b530b9034..3b687957dd 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -199,7 +199,14 @@ class TypingHandler(object): user_id = content["user_id"] # Check that the string is a valid user id - UserID.from_string(user_id) + user = UserID.from_string(user_id) + + if user.domain != origin: + logger.info( + "Got typing update from %r with bad 'user_id': %r", + origin, user_id, + ) + return users = yield self.state.get_current_user_in_room(room_id) domains = set(get_domain_from_id(u) for u in users) |