summary refs log tree commit diff
path: root/synapse/handlers/typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/typing.py')
-rw-r--r--synapse/handlers/typing.py143
1 files changed, 66 insertions, 77 deletions
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 8ce27f49ec..861b8f7989 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,8 +15,6 @@
 
 from twisted.internet import defer
 
-from ._base import BaseHandler
-
 from synapse.api.errors import SynapseError, AuthError
 from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.metrics import Measure
@@ -32,14 +30,16 @@ logger = logging.getLogger(__name__)
 
 # A tiny object useful for storing a user's membership in a room, as a mapping
 # key
-RoomMember = namedtuple("RoomMember", ("room_id", "user"))
+RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
 
 
-class TypingNotificationHandler(BaseHandler):
+class TypingHandler(object):
     def __init__(self, hs):
-        super(TypingNotificationHandler, self).__init__(hs)
-
-        self.homeserver = hs
+        self.store = hs.get_datastore()
+        self.server_name = hs.config.server_name
+        self.auth = hs.get_auth()
+        self.is_mine_id = hs.is_mine_id
+        self.notifier = hs.get_notifier()
 
         self.clock = hs.get_clock()
 
@@ -67,20 +67,23 @@ class TypingNotificationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def started_typing(self, target_user, auth_user, room_id, timeout):
-        if not self.hs.is_mine(target_user):
+        target_user_id = target_user.to_string()
+        auth_user_id = auth_user.to_string()
+
+        if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this Home Server")
 
-        if target_user != auth_user:
+        if target_user_id != auth_user_id:
             raise AuthError(400, "Cannot set another user's typing state")
 
-        yield self.auth.check_joined_room(room_id, target_user.to_string())
+        yield self.auth.check_joined_room(room_id, target_user_id)
 
         logger.debug(
-            "%s has started typing in %s", target_user.to_string(), room_id
+            "%s has started typing in %s", target_user_id, room_id
         )
 
         until = self.clock.time_msec() + timeout
-        member = RoomMember(room_id=room_id, user=target_user)
+        member = RoomMember(room_id=room_id, user_id=target_user_id)
 
         was_present = member in self._member_typing_until
 
@@ -104,25 +107,28 @@ class TypingNotificationHandler(BaseHandler):
 
         yield self._push_update(
             room_id=room_id,
-            user=target_user,
+            user_id=target_user_id,
             typing=True,
         )
 
     @defer.inlineCallbacks
     def stopped_typing(self, target_user, auth_user, room_id):
-        if not self.hs.is_mine(target_user):
+        target_user_id = target_user.to_string()
+        auth_user_id = auth_user.to_string()
+
+        if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this Home Server")
 
-        if target_user != auth_user:
+        if target_user_id != auth_user_id:
             raise AuthError(400, "Cannot set another user's typing state")
 
-        yield self.auth.check_joined_room(room_id, target_user.to_string())
+        yield self.auth.check_joined_room(room_id, target_user_id)
 
         logger.debug(
-            "%s has stopped typing in %s", target_user.to_string(), room_id
+            "%s has stopped typing in %s", target_user_id, room_id
         )
 
-        member = RoomMember(room_id=room_id, user=target_user)
+        member = RoomMember(room_id=room_id, user_id=target_user_id)
 
         if member in self._member_typing_timer:
             self.clock.cancel_call_later(self._member_typing_timer[member])
@@ -132,8 +138,9 @@ class TypingNotificationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def user_left_room(self, user, room_id):
-        if self.hs.is_mine(user):
-            member = RoomMember(room_id=room_id, user=user)
+        user_id = user.to_string()
+        if self.is_mine_id(user_id):
+            member = RoomMember(room_id=room_id, user_id=user_id)
             yield self._stopped_typing(member)
 
     @defer.inlineCallbacks
@@ -144,7 +151,7 @@ class TypingNotificationHandler(BaseHandler):
 
         yield self._push_update(
             room_id=member.room_id,
-            user=member.user,
+            user_id=member.user_id,
             typing=False,
         )
 
@@ -156,61 +163,53 @@ class TypingNotificationHandler(BaseHandler):
             del self._member_typing_timer[member]
 
     @defer.inlineCallbacks
-    def _push_update(self, room_id, user, typing):
-        localusers = set()
-        remotedomains = set()
-
-        rm_handler = self.homeserver.get_handlers().room_member_handler
-        yield rm_handler.fetch_room_distributions_into(
-            room_id, localusers=localusers, remotedomains=remotedomains
-        )
-
-        if localusers:
-            self._push_update_local(
-                room_id=room_id,
-                user=user,
-                typing=typing
-            )
+    def _push_update(self, room_id, user_id, typing):
+        domains = yield self.store.get_joined_hosts_for_room(room_id)
 
         deferreds = []
-        for domain in remotedomains:
-            deferreds.append(self.federation.send_edu(
-                destination=domain,
-                edu_type="m.typing",
-                content={
-                    "room_id": room_id,
-                    "user_id": user.to_string(),
-                    "typing": typing,
-                },
-            ))
+        for domain in domains:
+            if domain == self.server_name:
+                self._push_update_local(
+                    room_id=room_id,
+                    user_id=user_id,
+                    typing=typing
+                )
+            else:
+                deferreds.append(self.federation.send_edu(
+                    destination=domain,
+                    edu_type="m.typing",
+                    content={
+                        "room_id": room_id,
+                        "user_id": user_id,
+                        "typing": typing,
+                    },
+                ))
 
         yield defer.DeferredList(deferreds, consumeErrors=True)
 
     @defer.inlineCallbacks
     def _recv_edu(self, origin, content):
         room_id = content["room_id"]
-        user = UserID.from_string(content["user_id"])
+        user_id = content["user_id"]
 
-        localusers = set()
+        # Check that the string is a valid user id
+        UserID.from_string(user_id)
 
-        rm_handler = self.homeserver.get_handlers().room_member_handler
-        yield rm_handler.fetch_room_distributions_into(
-            room_id, localusers=localusers
-        )
+        domains = yield self.store.get_joined_hosts_for_room(room_id)
 
-        if localusers:
+        if self.server_name in domains:
             self._push_update_local(
                 room_id=room_id,
-                user=user,
+                user_id=user_id,
                 typing=content["typing"]
             )
 
-    def _push_update_local(self, room_id, user, typing):
+    def _push_update_local(self, room_id, user_id, typing):
         room_set = self._room_typing.setdefault(room_id, set())
         if typing:
-            room_set.add(user)
+            room_set.add(user_id)
         else:
-            room_set.discard(user)
+            room_set.discard(user_id)
 
         self._latest_room_serial += 1
         self._room_serials[room_id] = self._latest_room_serial
@@ -226,9 +225,7 @@ class TypingNotificationHandler(BaseHandler):
         for room_id, serial in self._room_serials.items():
             if last_id < serial and serial <= current_id:
                 typing = self._room_typing[room_id]
-                typing_bytes = json.dumps([
-                    u.to_string() for u in typing
-                ], ensure_ascii=False)
+                typing_bytes = json.dumps(list(typing), ensure_ascii=False)
                 rows.append((serial, room_id, typing_bytes))
         rows.sort()
         return rows
@@ -238,34 +235,26 @@ class TypingNotificationEventSource(object):
     def __init__(self, hs):
         self.hs = hs
         self.clock = hs.get_clock()
-        self._handler = None
-        self._room_member_handler = None
-
-    def handler(self):
-        # Avoid cyclic dependency in handler setup
-        if not self._handler:
-            self._handler = self.hs.get_handlers().typing_notification_handler
-        return self._handler
-
-    def room_member_handler(self):
-        if not self._room_member_handler:
-            self._room_member_handler = self.hs.get_handlers().room_member_handler
-        return self._room_member_handler
+        # We can't call get_typing_handler here because there's a cycle:
+        #
+        #   Typing -> Notifier -> TypingNotificationEventSource -> Typing
+        #
+        self.get_typing_handler = hs.get_typing_handler
 
     def _make_event_for(self, room_id):
-        typing = self.handler()._room_typing[room_id]
+        typing = self.get_typing_handler()._room_typing[room_id]
         return {
             "type": "m.typing",
             "room_id": room_id,
             "content": {
-                "user_ids": [u.to_string() for u in typing],
+                "user_ids": list(typing),
             },
         }
 
     def get_new_events(self, from_key, room_ids, **kwargs):
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
-            handler = self.handler()
+            handler = self.get_typing_handler()
 
             events = []
             for room_id in room_ids:
@@ -279,7 +268,7 @@ class TypingNotificationEventSource(object):
             return events, handler._latest_room_serial
 
     def get_current_key(self):
-        return self.handler()._latest_room_serial
+        return self.get_typing_handler()._latest_room_serial
 
     def get_pagination_rows(self, user, pagination_config, key):
         return ([], pagination_config.from_key)