summary refs log tree commit diff
path: root/synapse/handlers/typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/typing.py')
-rw-r--r--synapse/handlers/typing.py241
1 files changed, 167 insertions, 74 deletions
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 846ddbdc6c..a86ac0150e 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,15 +15,19 @@
 
 import logging
 from collections import namedtuple
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Set, Tuple
 
 from synapse.api.errors import AuthError, SynapseError
-from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.streams import TypingStream
 from synapse.types import UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -39,48 +43,48 @@ FEDERATION_TIMEOUT = 60 * 1000
 FEDERATION_PING_INTERVAL = 40 * 1000
 
 
-class TypingHandler(object):
-    def __init__(self, hs):
+class FollowerTypingHandler:
+    """A typing handler on a different process than the writer that is updated
+    via replication.
+    """
+
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.server_name = hs.config.server_name
-        self.auth = hs.get_auth()
-        self.is_mine_id = hs.is_mine_id
-        self.notifier = hs.get_notifier()
-        self.state = hs.get_state_handler()
-
-        self.hs = hs
-
         self.clock = hs.get_clock()
-        self.wheel_timer = WheelTimer(bucket_size=5000)
+        self.is_mine_id = hs.is_mine_id
 
-        self.federation = hs.get_federation_sender()
+        self.federation = None
+        if hs.should_send_federation():
+            self.federation = hs.get_federation_sender()
 
-        hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
+        if hs.config.worker.writers.typing != hs.get_instance_name():
+            hs.get_federation_registry().register_instance_for_edu(
+                "m.typing", hs.config.worker.writers.typing,
+            )
 
-        hs.get_distributor().observe("user_left_room", self.user_left_room)
+        # map room IDs to serial numbers
+        self._room_serials = {}
+        # map room IDs to sets of users currently typing
+        self._room_typing = {}
 
-        self._member_typing_until = {}  # clock time we expect to stop
         self._member_last_federation_poke = {}
-
+        self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
-        self._reset()
-
-        # caches which room_ids changed at which serials
-        self._typing_stream_change_cache = StreamChangeCache(
-            "TypingStreamChangeCache", self._latest_room_serial
-        )
 
         self.clock.looping_call(self._handle_timeouts, 5000)
 
     def _reset(self):
-        """
-        Reset the typing handler's data caches.
+        """Reset the typing handler's data caches.
         """
         # map room IDs to serial numbers
         self._room_serials = {}
         # map room IDs to sets of users currently typing
         self._room_typing = {}
 
+        self._member_last_federation_poke = {}
+        self.wheel_timer = WheelTimer(bucket_size=5000)
+
     def _handle_timeouts(self):
         logger.debug("Checking for typing timeouts")
 
@@ -89,30 +93,140 @@ class TypingHandler(object):
         members = set(self.wheel_timer.fetch(now))
 
         for member in members:
-            if not self.is_typing(member):
-                # Nothing to do if they're no longer typing
-                continue
-
-            until = self._member_typing_until.get(member, None)
-            if not until or until <= now:
-                logger.info("Timing out typing for: %s", member.user_id)
-                self._stopped_typing(member)
-                continue
-
-            # Check if we need to resend a keep alive over federation for this
-            # user.
-            if self.hs.is_mine_id(member.user_id):
-                last_fed_poke = self._member_last_federation_poke.get(member, None)
-                if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
-                    run_in_background(self._push_remote, member=member, typing=True)
-
-            # Add a paranoia timer to ensure that we always have a timer for
-            # each person typing.
-            self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
+            self._handle_timeout_for_member(now, member)
+
+    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+        if not self.is_typing(member):
+            # Nothing to do if they're no longer typing
+            return
+
+        # Check if we need to resend a keep alive over federation for this
+        # user.
+        if self.federation and self.is_mine_id(member.user_id):
+            last_fed_poke = self._member_last_federation_poke.get(member, None)
+            if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
+                run_as_background_process(
+                    "typing._push_remote", self._push_remote, member=member, typing=True
+                )
+
+        # Add a paranoia timer to ensure that we always have a timer for
+        # each person typing.
+        self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
 
     def is_typing(self, member):
         return member.user_id in self._room_typing.get(member.room_id, [])
 
+    async def _push_remote(self, member, typing):
+        if not self.federation:
+            return
+
+        try:
+            users = await self.store.get_users_in_room(member.room_id)
+            self._member_last_federation_poke[member] = self.clock.time_msec()
+
+            now = self.clock.time_msec()
+            self.wheel_timer.insert(
+                now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
+            )
+
+            for domain in {get_domain_from_id(u) for u in users}:
+                if domain != self.server_name:
+                    logger.debug("sending typing update to %s", domain)
+                    self.federation.build_and_send_edu(
+                        destination=domain,
+                        edu_type="m.typing",
+                        content={
+                            "room_id": member.room_id,
+                            "user_id": member.user_id,
+                            "typing": typing,
+                        },
+                        key=member,
+                    )
+        except Exception:
+            logger.exception("Error pushing typing notif to remotes")
+
+    def process_replication_rows(
+        self, token: int, rows: List[TypingStream.TypingStreamRow]
+    ):
+        """Should be called whenever we receive updates for typing stream.
+        """
+
+        if self._latest_room_serial > token:
+            # The master has gone backwards. To prevent inconsistent data, just
+            # clear everything.
+            self._reset()
+
+        # Set the latest serial token to whatever the server gave us.
+        self._latest_room_serial = token
+
+        for row in rows:
+            self._room_serials[row.room_id] = token
+
+            prev_typing = set(self._room_typing.get(row.room_id, []))
+            now_typing = set(row.user_ids)
+            self._room_typing[row.room_id] = row.user_ids
+
+            run_as_background_process(
+                "_handle_change_in_typing",
+                self._handle_change_in_typing,
+                row.room_id,
+                prev_typing,
+                now_typing,
+            )
+
+    async def _handle_change_in_typing(
+        self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
+    ):
+        """Process a change in typing of a room from replication, sending EDUs
+        for any local users.
+        """
+        for user_id in now_typing - prev_typing:
+            if self.is_mine_id(user_id):
+                await self._push_remote(RoomMember(room_id, user_id), True)
+
+        for user_id in prev_typing - now_typing:
+            if self.is_mine_id(user_id):
+                await self._push_remote(RoomMember(room_id, user_id), False)
+
+    def get_current_token(self):
+        return self._latest_room_serial
+
+
+class TypingWriterHandler(FollowerTypingHandler):
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        assert hs.config.worker.writers.typing == hs.get_instance_name()
+
+        self.auth = hs.get_auth()
+        self.notifier = hs.get_notifier()
+
+        self.hs = hs
+
+        hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
+
+        hs.get_distributor().observe("user_left_room", self.user_left_room)
+
+        self._member_typing_until = {}  # clock time we expect to stop
+
+        # caches which room_ids changed at which serials
+        self._typing_stream_change_cache = StreamChangeCache(
+            "TypingStreamChangeCache", self._latest_room_serial
+        )
+
+    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+        super()._handle_timeout_for_member(now, member)
+
+        if not self.is_typing(member):
+            # Nothing to do if they're no longer typing
+            return
+
+        until = self._member_typing_until.get(member, None)
+        if not until or until <= now:
+            logger.info("Timing out typing for: %s", member.user_id)
+            self._stopped_typing(member)
+            return
+
     async def started_typing(self, target_user, auth_user, room_id, timeout):
         target_user_id = target_user.to_string()
         auth_user_id = auth_user.to_string()
@@ -179,35 +293,11 @@ class TypingHandler(object):
     def _push_update(self, member, typing):
         if self.hs.is_mine_id(member.user_id):
             # Only send updates for changes to our own users.
-            run_in_background(self._push_remote, member, typing)
-
-        self._push_update_local(member=member, typing=typing)
-
-    async def _push_remote(self, member, typing):
-        try:
-            users = await self.store.get_users_in_room(member.room_id)
-            self._member_last_federation_poke[member] = self.clock.time_msec()
-
-            now = self.clock.time_msec()
-            self.wheel_timer.insert(
-                now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
+            run_as_background_process(
+                "typing._push_remote", self._push_remote, member, typing
             )
 
-            for domain in {get_domain_from_id(u) for u in users}:
-                if domain != self.server_name:
-                    logger.debug("sending typing update to %s", domain)
-                    self.federation.build_and_send_edu(
-                        destination=domain,
-                        edu_type="m.typing",
-                        content={
-                            "room_id": member.room_id,
-                            "user_id": member.user_id,
-                            "typing": typing,
-                        },
-                        key=member,
-                    )
-        except Exception:
-            logger.exception("Error pushing typing notif to remotes")
+        self._push_update_local(member=member, typing=typing)
 
     async def _recv_edu(self, origin, content):
         room_id = content["room_id"]
@@ -304,8 +394,11 @@ class TypingHandler(object):
 
         return rows, current_id, limited
 
-    def get_current_token(self):
-        return self._latest_room_serial
+    def process_replication_rows(
+        self, token: int, rows: List[TypingStream.TypingStreamRow]
+    ):
+        # The writing process should never get updates from replication.
+        raise Exception("Typing writer instance got typing info over replication")
 
 
 class TypingNotificationEventSource(object):