diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 8ce27f49ec..d46f05f426 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,8 +15,6 @@
from twisted.internet import defer
-from ._base import BaseHandler
-
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.metrics import Measure
@@ -35,11 +33,13 @@ logger = logging.getLogger(__name__)
RoomMember = namedtuple("RoomMember", ("room_id", "user"))
-class TypingNotificationHandler(BaseHandler):
+class TypingHandler(object):
def __init__(self, hs):
- super(TypingNotificationHandler, self).__init__(hs)
-
- self.homeserver = hs
+ self.store = hs.get_datastore()
+ self.server_name = hs.config.server_name
+ self.auth = hs.get_auth()
+ self.is_mine = hs.is_mine
+ self.notifier = hs.get_notifier()
self.clock = hs.get_clock()
@@ -67,7 +67,7 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout):
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user:
@@ -110,7 +110,7 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id):
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user:
@@ -132,7 +132,7 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
- if self.hs.is_mine(user):
+ if self.is_mine(user):
member = RoomMember(room_id=room_id, user=user)
yield self._stopped_typing(member)
@@ -157,32 +157,26 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def _push_update(self, room_id, user, typing):
- localusers = set()
- remotedomains = set()
-
- rm_handler = self.homeserver.get_handlers().room_member_handler
- yield rm_handler.fetch_room_distributions_into(
- room_id, localusers=localusers, remotedomains=remotedomains
- )
-
- if localusers:
- self._push_update_local(
- room_id=room_id,
- user=user,
- typing=typing
- )
+ domains = yield self.store.get_joined_hosts_for_room(room_id)
deferreds = []
- for domain in remotedomains:
- deferreds.append(self.federation.send_edu(
- destination=domain,
- edu_type="m.typing",
- content={
- "room_id": room_id,
- "user_id": user.to_string(),
- "typing": typing,
- },
- ))
+ for domain in domains:
+ if domain == self.server_name:
+ self._push_update_local(
+ room_id=room_id,
+ user=user,
+ typing=typing
+ )
+ else:
+ deferreds.append(self.federation.send_edu(
+ destination=domain,
+ edu_type="m.typing",
+ content={
+ "room_id": room_id,
+ "user_id": user.to_string(),
+ "typing": typing,
+ },
+ ))
yield defer.DeferredList(deferreds, consumeErrors=True)
@@ -191,14 +185,9 @@ class TypingNotificationHandler(BaseHandler):
room_id = content["room_id"]
user = UserID.from_string(content["user_id"])
- localusers = set()
-
- rm_handler = self.homeserver.get_handlers().room_member_handler
- yield rm_handler.fetch_room_distributions_into(
- room_id, localusers=localusers
- )
+ domains = yield self.store.get_joined_hosts_for_room(room_id)
- if localusers:
+ if self.server_name in domains:
self._push_update_local(
room_id=room_id,
user=user,
@@ -238,22 +227,14 @@ class TypingNotificationEventSource(object):
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
- self._handler = None
- self._room_member_handler = None
-
- def handler(self):
- # Avoid cyclic dependency in handler setup
- if not self._handler:
- self._handler = self.hs.get_handlers().typing_notification_handler
- return self._handler
-
- def room_member_handler(self):
- if not self._room_member_handler:
- self._room_member_handler = self.hs.get_handlers().room_member_handler
- return self._room_member_handler
+ # We can't call get_typing_handler here because there's a cycle:
+ #
+ # Typing -> Notifier -> TypingNotificationEventSource -> Typing
+ #
+ self.get_typing_handler = hs.get_typing_handler
def _make_event_for(self, room_id):
- typing = self.handler()._room_typing[room_id]
+ typing = self.get_typing_handler()._room_typing[room_id]
return {
"type": "m.typing",
"room_id": room_id,
@@ -265,7 +246,7 @@ class TypingNotificationEventSource(object):
def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
- handler = self.handler()
+ handler = self.get_typing_handler()
events = []
for room_id in room_ids:
@@ -279,7 +260,7 @@ class TypingNotificationEventSource(object):
return events, handler._latest_room_serial
def get_current_key(self):
- return self.handler()._latest_room_serial
+ return self.get_typing_handler()._latest_room_serial
def get_pagination_rows(self, user, pagination_config, key):
return ([], pagination_config.from_key)
|