summary refs log tree commit diff
path: root/synapse/handlers/typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/typing.py')
-rw-r--r--synapse/handlers/typing.py93
1 files changed, 37 insertions, 56 deletions
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py

index 8ce27f49ec..d46f05f426 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py
@@ -15,8 +15,6 @@ from twisted.internet import defer -from ._base import BaseHandler - from synapse.api.errors import SynapseError, AuthError from synapse.util.logcontext import PreserveLoggingContext from synapse.util.metrics import Measure @@ -35,11 +33,13 @@ logger = logging.getLogger(__name__) RoomMember = namedtuple("RoomMember", ("room_id", "user")) -class TypingNotificationHandler(BaseHandler): +class TypingHandler(object): def __init__(self, hs): - super(TypingNotificationHandler, self).__init__(hs) - - self.homeserver = hs + self.store = hs.get_datastore() + self.server_name = hs.config.server_name + self.auth = hs.get_auth() + self.is_mine = hs.is_mine + self.notifier = hs.get_notifier() self.clock = hs.get_clock() @@ -67,7 +67,7 @@ class TypingNotificationHandler(BaseHandler): @defer.inlineCallbacks def started_typing(self, target_user, auth_user, room_id, timeout): - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") if target_user != auth_user: @@ -110,7 +110,7 @@ class TypingNotificationHandler(BaseHandler): @defer.inlineCallbacks def stopped_typing(self, target_user, auth_user, room_id): - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") if target_user != auth_user: @@ -132,7 +132,7 @@ class TypingNotificationHandler(BaseHandler): @defer.inlineCallbacks def user_left_room(self, user, room_id): - if self.hs.is_mine(user): + if self.is_mine(user): member = RoomMember(room_id=room_id, user=user) yield self._stopped_typing(member) @@ -157,32 +157,26 @@ class TypingNotificationHandler(BaseHandler): @defer.inlineCallbacks def _push_update(self, room_id, user, typing): - localusers = set() - remotedomains = set() - - rm_handler = self.homeserver.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into( - room_id, localusers=localusers, remotedomains=remotedomains - ) - - if localusers: - self._push_update_local( - room_id=room_id, - user=user, - typing=typing - ) + domains = yield self.store.get_joined_hosts_for_room(room_id) deferreds = [] - for domain in remotedomains: - deferreds.append(self.federation.send_edu( - destination=domain, - edu_type="m.typing", - content={ - "room_id": room_id, - "user_id": user.to_string(), - "typing": typing, - }, - )) + for domain in domains: + if domain == self.server_name: + self._push_update_local( + room_id=room_id, + user=user, + typing=typing + ) + else: + deferreds.append(self.federation.send_edu( + destination=domain, + edu_type="m.typing", + content={ + "room_id": room_id, + "user_id": user.to_string(), + "typing": typing, + }, + )) yield defer.DeferredList(deferreds, consumeErrors=True) @@ -191,14 +185,9 @@ class TypingNotificationHandler(BaseHandler): room_id = content["room_id"] user = UserID.from_string(content["user_id"]) - localusers = set() - - rm_handler = self.homeserver.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into( - room_id, localusers=localusers - ) + domains = yield self.store.get_joined_hosts_for_room(room_id) - if localusers: + if self.server_name in domains: self._push_update_local( room_id=room_id, user=user, @@ -238,22 +227,14 @@ class TypingNotificationEventSource(object): def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() - self._handler = None - self._room_member_handler = None - - def handler(self): - # Avoid cyclic dependency in handler setup - if not self._handler: - self._handler = self.hs.get_handlers().typing_notification_handler - return self._handler - - def room_member_handler(self): - if not self._room_member_handler: - self._room_member_handler = self.hs.get_handlers().room_member_handler - return self._room_member_handler + # We can't call get_typing_handler here because there's a cycle: + # + # Typing -> Notifier -> TypingNotificationEventSource -> Typing + # + self.get_typing_handler = hs.get_typing_handler def _make_event_for(self, room_id): - typing = self.handler()._room_typing[room_id] + typing = self.get_typing_handler()._room_typing[room_id] return { "type": "m.typing", "room_id": room_id, @@ -265,7 +246,7 @@ class TypingNotificationEventSource(object): def get_new_events(self, from_key, room_ids, **kwargs): with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) - handler = self.handler() + handler = self.get_typing_handler() events = [] for room_id in room_ids: @@ -279,7 +260,7 @@ class TypingNotificationEventSource(object): return events, handler._latest_room_serial def get_current_key(self): - return self.handler()._latest_room_serial + return self.get_typing_handler()._latest_room_serial def get_pagination_rows(self, user, pagination_config, key): return ([], pagination_config.from_key)