summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/federation/federation_client.py2
-rw-r--r--synapse/handlers/typing.py175
-rw-r--r--synapse/rest/client/v1/room.py5
-rw-r--r--tests/handlers/test_typing.py7
-rw-r--r--tests/rest/client/v1/test_typing.py5
-rw-r--r--tests/utils.py9
6 files changed, 120 insertions, 83 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 06d0320b1a..94e76b1978 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -136,9 +136,7 @@ class FederationClient(FederationBase):
 
         sent_edus_counter.inc()
 
-        # TODO, add errback, etc.
         self._transaction_queue.enqueue_edu(edu, key=key)
-        return defer.succeed(None)
 
     @log_function
     def send_device_messages(self, destination):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0548b81c34..505a68d142 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -16,10 +16,9 @@
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import preserve_fn
 from synapse.util.metrics import Measure
+from synapse.util.wheel_timer import WheelTimer
 from synapse.types import UserID, get_domain_from_id
 
 import logging
@@ -35,6 +34,13 @@ logger = logging.getLogger(__name__)
 RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
 
 
+# How often we expect remote servers to resend us presence.
+FEDERATION_TIMEOUT = 60 * 1000
+
+# How often to resend typing across federation.
+FEDERATION_PING_INTERVAL = 40 * 1000
+
+
 class TypingHandler(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -44,7 +50,10 @@ class TypingHandler(object):
         self.notifier = hs.get_notifier()
         self.state = hs.get_state_handler()
 
+        self.hs = hs
+
         self.clock = hs.get_clock()
+        self.wheel_timer = WheelTimer()
 
         self.federation = hs.get_replication_layer()
 
@@ -53,7 +62,7 @@ class TypingHandler(object):
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
         self._member_typing_until = {}  # clock time we expect to stop
-        self._member_typing_timer = {}  # deferreds to manage theabove
+        self._member_last_federation_poke = {}
 
         # map room IDs to serial numbers
         self._room_serials = {}
@@ -61,12 +70,41 @@ class TypingHandler(object):
         # map room IDs to sets of users currently typing
         self._room_typing = {}
 
-    def tearDown(self):
-        """Cancels all the pending timers.
-        Normally this shouldn't be needed, but it's required from unit tests
-        to avoid a "Reactor was unclean" warning."""
-        for t in self._member_typing_timer.values():
-            self.clock.cancel_call_later(t)
+        self.clock.looping_call(
+            self._handle_timeouts,
+            5000,
+        )
+
+    def _handle_timeouts(self):
+        logger.info("Handling typing timeout")
+
+        now = self.clock.time_msec()
+
+        members = set(self.wheel_timer.fetch(now))
+
+        for member in members:
+            if not self.is_typing(member):
+                # Nothing to do if they're no longer typing
+                continue
+
+            until = self._member_typing_until.get(member, None)
+            if not until or until < now:
+                logger.info("Timing out typing for: %s", member.user_id)
+                preserve_fn(self._stopped_typing)(member)
+                continue
+
+            # Check if we need to resend a keep alive over federation for this
+            # user.
+            if self.hs.is_mine_id(member.user_id):
+                last_fed_poke = self._member_last_federation_poke.get(member, None)
+                if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now:
+                    preserve_fn(self._push_remote)(
+                        member=member,
+                        typing=True
+                    )
+
+    def is_typing(self, member):
+        return member.user_id in self._room_typing.get(member.room_id, [])
 
     @defer.inlineCallbacks
     def started_typing(self, target_user, auth_user, room_id, timeout):
@@ -85,23 +123,23 @@ class TypingHandler(object):
             "%s has started typing in %s", target_user_id, room_id
         )
 
-        until = self.clock.time_msec() + timeout
         member = RoomMember(room_id=room_id, user_id=target_user_id)
 
-        was_present = member in self._member_typing_until
+        was_present = member.user_id in self._room_typing.get(room_id, set())
 
-        if member in self._member_typing_timer:
-            self.clock.cancel_call_later(self._member_typing_timer[member])
+        now = self.clock.time_msec()
+        self._member_typing_until[member] = now + timeout
 
-        def _cb():
-            logger.debug(
-                "%s has timed out in %s", target_user.to_string(), room_id
-            )
-            self._stopped_typing(member)
+        self.wheel_timer.insert(
+            now=now,
+            obj=member,
+            then=now + timeout,
+        )
 
-        self._member_typing_until[member] = until
-        self._member_typing_timer[member] = self.clock.call_later(
-            timeout / 1000.0, _cb
+        self.wheel_timer.insert(
+            now=now,
+            obj=member,
+            then=now + FEDERATION_PING_INTERVAL,
         )
 
         if was_present:
@@ -109,8 +147,7 @@ class TypingHandler(object):
             defer.returnValue(None)
 
         yield self._push_update(
-            room_id=room_id,
-            user_id=target_user_id,
+            member=member,
             typing=True,
         )
 
@@ -133,10 +170,6 @@ class TypingHandler(object):
 
         member = RoomMember(room_id=room_id, user_id=target_user_id)
 
-        if member in self._member_typing_timer:
-            self.clock.cancel_call_later(self._member_typing_timer[member])
-            del self._member_typing_timer[member]
-
         yield self._stopped_typing(member)
 
     @defer.inlineCallbacks
@@ -148,57 +181,53 @@ class TypingHandler(object):
 
     @defer.inlineCallbacks
     def _stopped_typing(self, member):
-        if member not in self._member_typing_until:
+        if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
             defer.returnValue(None)
 
+        self._member_typing_until.pop(member, None)
+        self._member_last_federation_poke.pop(member, None)
+
         yield self._push_update(
-            room_id=member.room_id,
-            user_id=member.user_id,
+            member=member,
             typing=False,
         )
 
-        del self._member_typing_until[member]
-
-        if member in self._member_typing_timer:
-            # Don't cancel it - either it already expired, or the real
-            # stopped_typing() will cancel it
-            del self._member_typing_timer[member]
-
     @defer.inlineCallbacks
-    def _push_update(self, room_id, user_id, typing):
-        users = yield self.state.get_current_user_in_room(room_id)
-        domains = set(get_domain_from_id(u) for u in users)
+    def _push_update(self, member, typing):
+        if self.hs.is_mine_id(member.user_id):
+            # Only send updates for changes to our own users.
+            yield self._push_remote(member, typing)
+
+        self._push_update_local(
+            member=member,
+            typing=typing
+        )
 
-        deferreds = []
-        for domain in domains:
-            if domain == self.server_name:
-                preserve_fn(self._push_update_local)(
-                    room_id=room_id,
-                    user_id=user_id,
-                    typing=typing
-                )
-            else:
-                deferreds.append(preserve_fn(self.federation.send_edu)(
+    @defer.inlineCallbacks
+    def _push_remote(self, member, typing):
+        users = yield self.state.get_current_user_in_room(member.room_id)
+        self._member_last_federation_poke[member] = self.clock.time_msec()
+        for domain in set(get_domain_from_id(u) for u in users):
+            if domain != self.server_name:
+                self.federation.send_edu(
                     destination=domain,
                     edu_type="m.typing",
                     content={
-                        "room_id": room_id,
-                        "user_id": user_id,
+                        "room_id": member.room_id,
+                        "user_id": member.user_id,
                         "typing": typing,
                     },
-                    key=(room_id, user_id),
-                ))
-
-        yield preserve_context_over_deferred(
-            defer.DeferredList(deferreds, consumeErrors=True)
-        )
+                    key=member,
+                )
 
     @defer.inlineCallbacks
     def _recv_edu(self, origin, content):
         room_id = content["room_id"]
         user_id = content["user_id"]
 
+        member = RoomMember(user_id=user_id, room_id=room_id)
+
         # Check that the string is a valid user id
         user = UserID.from_string(user_id)
 
@@ -213,26 +242,32 @@ class TypingHandler(object):
         domains = set(get_domain_from_id(u) for u in users)
 
         if self.server_name in domains:
+            logger.info("Got typing update from %s: %r", user_id, content)
+            now = self.clock.time_msec()
+            self._member_typing_until[member] = now + FEDERATION_TIMEOUT
+            self.wheel_timer.insert(
+                now=now,
+                obj=member,
+                then=now + FEDERATION_TIMEOUT,
+            )
             self._push_update_local(
-                room_id=room_id,
-                user_id=user_id,
+                member=member,
                 typing=content["typing"]
             )
 
-    def _push_update_local(self, room_id, user_id, typing):
-        room_set = self._room_typing.setdefault(room_id, set())
+    def _push_update_local(self, member, typing):
+        room_set = self._room_typing.setdefault(member.room_id, set())
         if typing:
-            room_set.add(user_id)
+            room_set.add(member.user_id)
         else:
-            room_set.discard(user_id)
+            room_set.discard(member.user_id)
 
         self._latest_room_serial += 1
-        self._room_serials[room_id] = self._latest_room_serial
+        self._room_serials[member.room_id] = self._latest_room_serial
 
-        with PreserveLoggingContext():
-            self.notifier.on_new_event(
-                "typing_key", self._latest_room_serial, rooms=[room_id]
-            )
+        self.notifier.on_new_event(
+            "typing_key", self._latest_room_serial, rooms=[member.room_id]
+        )
 
     def get_all_typing_updates(self, last_id, current_id):
         # TODO: Work out a way to do this without scanning the entire state.
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 20889e4af0..010fbc7c32 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -705,12 +705,15 @@ class RoomTypingRestServlet(ClientV1RestServlet):
 
         yield self.presence_handler.bump_presence_active_time(requester.user)
 
+        # Limit timeout to stop people from setting silly typing timeouts.
+        timeout = min(content.get("timeout", 30000), 120000)
+
         if content["typing"]:
             yield self.typing_handler.started_typing(
                 target_user=target_user,
                 auth_user=requester.user,
                 room_id=room_id,
-                timeout=content.get("timeout", 30000),
+                timeout=timeout,
             )
         else:
             yield self.typing_handler.stopped_typing(
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index ea1f0f7c33..c3108f5181 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -267,10 +267,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
         from synapse.handlers.typing import RoomMember
         member = RoomMember(self.room_id, self.u_apple.to_string())
         self.handler._member_typing_until[member] = 1002000
-        self.handler._member_typing_timer[member] = (
-            self.clock.call_later(1002, lambda: 0)
-        )
-        self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),))
+        self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()])
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
@@ -330,7 +327,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
             },
         }])
 
-        self.clock.advance_time(11)
+        self.clock.advance_time(16)
 
         self.on_new_event.assert_has_calls([
             call('typing_key', 2, rooms=[self.room_id]),
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 467f253ef6..a269e6f56e 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -105,9 +105,6 @@ class RoomTypingTestCase(RestTestCase):
         # Need another user to make notifications actually work
         yield self.join(self.room_id, user="@jim:red")
 
-    def tearDown(self):
-        self.hs.get_typing_handler().tearDown()
-
     @defer.inlineCallbacks
     def test_set_typing(self):
         (code, _) = yield self.mock_resource.trigger(
@@ -147,7 +144,7 @@ class RoomTypingTestCase(RestTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 1)
 
-        self.clock.advance_time(31)
+        self.clock.advance_time(36)
 
         self.assertEquals(self.event_source.get_current_key(), 2)
 
diff --git a/tests/utils.py b/tests/utils.py
index 915b934e94..92d470cb48 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -220,6 +220,7 @@ class MockClock(object):
         # list of lists of [absolute_time, callback, expired] in no particular
         # order
         self.timers = []
+        self.loopers = []
 
     def time(self):
         return self.now
@@ -240,7 +241,7 @@ class MockClock(object):
         return t
 
     def looping_call(self, function, interval):
-        pass
+        self.loopers.append([function, interval / 1000., self.now])
 
     def cancel_call_later(self, timer, ignore_errs=False):
         if timer[2]:
@@ -269,6 +270,12 @@ class MockClock(object):
             else:
                 self.timers.append(t)
 
+        for looped in self.loopers:
+            func, interval, last = looped
+            if last + interval < self.now:
+                func()
+                looped[2] = self.now
+
     def advance_time_msec(self, ms):
         self.advance_time(ms / 1000.)