summary refs log tree commit diff
path: root/synapse/handlers/typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/typing.py')
-rw-r--r--synapse/handlers/typing.py37
1 files changed, 30 insertions, 7 deletions
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 2d2d3d5a0d..a61bbf9392 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import AuthError, SynapseError
 from synapse.types import UserID, get_domain_from_id
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.logcontext import run_in_background
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -62,17 +63,28 @@ class TypingHandler(object):
         self._member_typing_until = {}  # clock time we expect to stop
         self._member_last_federation_poke = {}
 
-        # map room IDs to serial numbers
-        self._room_serials = {}
         self._latest_room_serial = 0
-        # map room IDs to sets of users currently typing
-        self._room_typing = {}
+        self._reset()
+
+        # caches which room_ids changed at which serials
+        self._typing_stream_change_cache = StreamChangeCache(
+            "TypingStreamChangeCache", self._latest_room_serial,
+        )
 
         self.clock.looping_call(
             self._handle_timeouts,
             5000,
         )
 
+    def _reset(self):
+        """
+        Reset the typing handler's data caches.
+        """
+        # map room IDs to serial numbers
+        self._room_serials = {}
+        # map room IDs to sets of users currently typing
+        self._room_typing = {}
+
     def _handle_timeouts(self):
         logger.info("Checking for typing timeouts")
 
@@ -218,6 +230,7 @@ class TypingHandler(object):
 
             for domain in set(get_domain_from_id(u) for u in users):
                 if domain != self.server_name:
+                    logger.debug("sending typing update to %s", domain)
                     self.federation.send_edu(
                         destination=domain,
                         edu_type="m.typing",
@@ -274,19 +287,29 @@ class TypingHandler(object):
 
         self._latest_room_serial += 1
         self._room_serials[member.room_id] = self._latest_room_serial
+        self._typing_stream_change_cache.entity_has_changed(
+            member.room_id, self._latest_room_serial,
+        )
 
         self.notifier.on_new_event(
             "typing_key", self._latest_room_serial, rooms=[member.room_id]
         )
 
     def get_all_typing_updates(self, last_id, current_id):
-        # TODO: Work out a way to do this without scanning the entire state.
         if last_id == current_id:
             return []
 
+        changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
+            last_id,
+        )
+
+        if changed_rooms is None:
+            changed_rooms = self._room_serials
+
         rows = []
-        for room_id, serial in self._room_serials.items():
-            if last_id < serial and serial <= current_id:
+        for room_id in changed_rooms:
+            serial = self._room_serials[room_id]
+            if last_id < serial <= current_id:
                 typing = self._room_typing[room_id]
                 rows.append((serial, room_id, list(typing)))
         rows.sort()