summary refs log tree commit diff
path: root/synapse/handlers/typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/typing.py')
-rw-r--r--synapse/handlers/typing.py23
1 files changed, 20 insertions, 3 deletions
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 2d2d3d5a0d..c610933dd4 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import AuthError, SynapseError
 from synapse.types import UserID, get_domain_from_id
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.logcontext import run_in_background
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -68,6 +69,11 @@ class TypingHandler(object):
         # map room IDs to sets of users currently typing
         self._room_typing = {}
 
+        # caches which room_ids changed at which serials
+        self._typing_stream_change_cache = StreamChangeCache(
+            "TypingStreamChangeCache", self._latest_room_serial,
+        )
+
         self.clock.looping_call(
             self._handle_timeouts,
             5000,
@@ -218,6 +224,7 @@ class TypingHandler(object):
 
             for domain in set(get_domain_from_id(u) for u in users):
                 if domain != self.server_name:
+                    logger.debug("sending typing update to %s", domain)
                     self.federation.send_edu(
                         destination=domain,
                         edu_type="m.typing",
@@ -274,19 +281,29 @@ class TypingHandler(object):
 
         self._latest_room_serial += 1
         self._room_serials[member.room_id] = self._latest_room_serial
+        self._typing_stream_change_cache.entity_has_changed(
+            member.room_id, self._latest_room_serial,
+        )
 
         self.notifier.on_new_event(
             "typing_key", self._latest_room_serial, rooms=[member.room_id]
         )
 
     def get_all_typing_updates(self, last_id, current_id):
-        # TODO: Work out a way to do this without scanning the entire state.
         if last_id == current_id:
             return []
 
+        changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
+            last_id,
+        )
+
+        if changed_rooms is None:
+            changed_rooms = self._room_serials
+
         rows = []
-        for room_id, serial in self._room_serials.items():
-            if last_id < serial and serial <= current_id:
+        for room_id in changed_rooms:
+            serial = self._room_serials[room_id]
+            if last_id < serial <= current_id:
                 typing = self._room_typing[room_id]
                 rows.append((serial, room_id, list(typing)))
         rows.sort()