summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/typing.py177
1 files changed, 107 insertions, 70 deletions
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0548b81c34..08313417b2 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -16,10 +16,9 @@
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import preserve_fn
 from synapse.util.metrics import Measure
+from synapse.util.wheel_timer import WheelTimer
 from synapse.types import UserID, get_domain_from_id
 
 import logging
@@ -35,6 +34,13 @@ logger = logging.getLogger(__name__)
 RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
 
 
+# How often we expect remote servers to resend us presence.
+FEDERATION_TIMEOUT = 60 * 1000
+
+# How often to resend typing across federation.
+FEDERATION_PING_INTERVAL = 40 * 1000
+
+
 class TypingHandler(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -44,7 +50,10 @@ class TypingHandler(object):
         self.notifier = hs.get_notifier()
         self.state = hs.get_state_handler()
 
+        self.hs = hs
+
         self.clock = hs.get_clock()
+        self.wheel_timer = WheelTimer(bucket_size=5000)
 
         self.federation = hs.get_replication_layer()
 
@@ -53,7 +62,7 @@ class TypingHandler(object):
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
         self._member_typing_until = {}  # clock time we expect to stop
-        self._member_typing_timer = {}  # deferreds to manage theabove
+        self._member_last_federation_poke = {}
 
         # map room IDs to serial numbers
         self._room_serials = {}
@@ -61,12 +70,41 @@ class TypingHandler(object):
         # map room IDs to sets of users currently typing
         self._room_typing = {}
 
-    def tearDown(self):
-        """Cancels all the pending timers.
-        Normally this shouldn't be needed, but it's required from unit tests
-        to avoid a "Reactor was unclean" warning."""
-        for t in self._member_typing_timer.values():
-            self.clock.cancel_call_later(t)
+        self.clock.looping_call(
+            self._handle_timeouts,
+            5000,
+        )
+
+    def _handle_timeouts(self):
+        logger.info("Checking for typing timeouts")
+
+        now = self.clock.time_msec()
+
+        members = set(self.wheel_timer.fetch(now))
+
+        for member in members:
+            if not self.is_typing(member):
+                # Nothing to do if they're no longer typing
+                continue
+
+            until = self._member_typing_until.get(member, None)
+            if not until or until < now:
+                logger.info("Timing out typing for: %s", member.user_id)
+                preserve_fn(self._stopped_typing)(member)
+                continue
+
+            # Check if we need to resend a keep alive over federation for this
+            # user.
+            if self.hs.is_mine_id(member.user_id):
+                last_fed_poke = self._member_last_federation_poke.get(member, None)
+                if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now:
+                    preserve_fn(self._push_remote)(
+                        member=member,
+                        typing=True
+                    )
+
+    def is_typing(self, member):
+        return member.user_id in self._room_typing.get(member.room_id, [])
 
     @defer.inlineCallbacks
     def started_typing(self, target_user, auth_user, room_id, timeout):
@@ -85,23 +123,17 @@ class TypingHandler(object):
             "%s has started typing in %s", target_user_id, room_id
         )
 
-        until = self.clock.time_msec() + timeout
         member = RoomMember(room_id=room_id, user_id=target_user_id)
 
-        was_present = member in self._member_typing_until
-
-        if member in self._member_typing_timer:
-            self.clock.cancel_call_later(self._member_typing_timer[member])
+        was_present = member.user_id in self._room_typing.get(room_id, set())
 
-        def _cb():
-            logger.debug(
-                "%s has timed out in %s", target_user.to_string(), room_id
-            )
-            self._stopped_typing(member)
+        now = self.clock.time_msec()
+        self._member_typing_until[member] = now + timeout
 
-        self._member_typing_until[member] = until
-        self._member_typing_timer[member] = self.clock.call_later(
-            timeout / 1000.0, _cb
+        self.wheel_timer.insert(
+            now=now,
+            obj=member,
+            then=now + timeout,
         )
 
         if was_present:
@@ -109,8 +141,7 @@ class TypingHandler(object):
             defer.returnValue(None)
 
         yield self._push_update(
-            room_id=room_id,
-            user_id=target_user_id,
+            member=member,
             typing=True,
         )
 
@@ -133,10 +164,6 @@ class TypingHandler(object):
 
         member = RoomMember(room_id=room_id, user_id=target_user_id)
 
-        if member in self._member_typing_timer:
-            self.clock.cancel_call_later(self._member_typing_timer[member])
-            del self._member_typing_timer[member]
-
         yield self._stopped_typing(member)
 
     @defer.inlineCallbacks
@@ -148,57 +175,61 @@ class TypingHandler(object):
 
     @defer.inlineCallbacks
     def _stopped_typing(self, member):
-        if member not in self._member_typing_until:
+        if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
             defer.returnValue(None)
 
+        self._member_typing_until.pop(member, None)
+        self._member_last_federation_poke.pop(member, None)
+
         yield self._push_update(
-            room_id=member.room_id,
-            user_id=member.user_id,
+            member=member,
             typing=False,
         )
 
-        del self._member_typing_until[member]
-
-        if member in self._member_typing_timer:
-            # Don't cancel it - either it already expired, or the real
-            # stopped_typing() will cancel it
-            del self._member_typing_timer[member]
+    @defer.inlineCallbacks
+    def _push_update(self, member, typing):
+        if self.hs.is_mine_id(member.user_id):
+            # Only send updates for changes to our own users.
+            yield self._push_remote(member, typing)
+
+        self._push_update_local(
+            member=member,
+            typing=typing
+        )
 
     @defer.inlineCallbacks
-    def _push_update(self, room_id, user_id, typing):
-        users = yield self.state.get_current_user_in_room(room_id)
-        domains = set(get_domain_from_id(u) for u in users)
+    def _push_remote(self, member, typing):
+        users = yield self.state.get_current_user_in_room(member.room_id)
+        self._member_last_federation_poke[member] = self.clock.time_msec()
+
+        now = self.clock.time_msec()
+        self.wheel_timer.insert(
+            now=now,
+            obj=member,
+            then=now + FEDERATION_PING_INTERVAL,
+        )
 
-        deferreds = []
-        for domain in domains:
-            if domain == self.server_name:
-                preserve_fn(self._push_update_local)(
-                    room_id=room_id,
-                    user_id=user_id,
-                    typing=typing
-                )
-            else:
-                deferreds.append(preserve_fn(self.federation.send_edu)(
+        for domain in set(get_domain_from_id(u) for u in users):
+            if domain != self.server_name:
+                self.federation.send_edu(
                     destination=domain,
                     edu_type="m.typing",
                     content={
-                        "room_id": room_id,
-                        "user_id": user_id,
+                        "room_id": member.room_id,
+                        "user_id": member.user_id,
                         "typing": typing,
                     },
-                    key=(room_id, user_id),
-                ))
-
-        yield preserve_context_over_deferred(
-            defer.DeferredList(deferreds, consumeErrors=True)
-        )
+                    key=member,
+                )
 
     @defer.inlineCallbacks
     def _recv_edu(self, origin, content):
         room_id = content["room_id"]
         user_id = content["user_id"]
 
+        member = RoomMember(user_id=user_id, room_id=room_id)
+
         # Check that the string is a valid user id
         user = UserID.from_string(user_id)
 
@@ -213,26 +244,32 @@ class TypingHandler(object):
         domains = set(get_domain_from_id(u) for u in users)
 
         if self.server_name in domains:
+            logger.info("Got typing update from %s: %r", user_id, content)
+            now = self.clock.time_msec()
+            self._member_typing_until[member] = now + FEDERATION_TIMEOUT
+            self.wheel_timer.insert(
+                now=now,
+                obj=member,
+                then=now + FEDERATION_TIMEOUT,
+            )
             self._push_update_local(
-                room_id=room_id,
-                user_id=user_id,
+                member=member,
                 typing=content["typing"]
             )
 
-    def _push_update_local(self, room_id, user_id, typing):
-        room_set = self._room_typing.setdefault(room_id, set())
+    def _push_update_local(self, member, typing):
+        room_set = self._room_typing.setdefault(member.room_id, set())
         if typing:
-            room_set.add(user_id)
+            room_set.add(member.user_id)
         else:
-            room_set.discard(user_id)
+            room_set.discard(member.user_id)
 
         self._latest_room_serial += 1
-        self._room_serials[room_id] = self._latest_room_serial
+        self._room_serials[member.room_id] = self._latest_room_serial
 
-        with PreserveLoggingContext():
-            self.notifier.on_new_event(
-                "typing_key", self._latest_room_serial, rooms=[room_id]
-            )
+        self.notifier.on_new_event(
+            "typing_key", self._latest_room_serial, rooms=[member.room_id]
+        )
 
     def get_all_typing_updates(self, last_id, current_id):
         # TODO: Work out a way to do this without scanning the entire state.