summary refs log tree commit diff
path: root/synapse/handlers/typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/typing.py')
-rw-r--r--synapse/handlers/typing.py307
1 files changed, 210 insertions, 97 deletions
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c7bc14c623..a86ac0150e 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,17 +15,19 @@
 
 import logging
 from collections import namedtuple
-from typing import List
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING, List, Set, Tuple
 
 from synapse.api.errors import AuthError, SynapseError
-from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.streams import TypingStream
 from synapse.types import UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -41,48 +43,48 @@ FEDERATION_TIMEOUT = 60 * 1000
 FEDERATION_PING_INTERVAL = 40 * 1000
 
 
-class TypingHandler(object):
-    def __init__(self, hs):
+class FollowerTypingHandler:
+    """A typing handler on a different process than the writer that is updated
+    via replication.
+    """
+
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.server_name = hs.config.server_name
-        self.auth = hs.get_auth()
-        self.is_mine_id = hs.is_mine_id
-        self.notifier = hs.get_notifier()
-        self.state = hs.get_state_handler()
-
-        self.hs = hs
-
         self.clock = hs.get_clock()
-        self.wheel_timer = WheelTimer(bucket_size=5000)
+        self.is_mine_id = hs.is_mine_id
 
-        self.federation = hs.get_federation_sender()
+        self.federation = None
+        if hs.should_send_federation():
+            self.federation = hs.get_federation_sender()
 
-        hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
+        if hs.config.worker.writers.typing != hs.get_instance_name():
+            hs.get_federation_registry().register_instance_for_edu(
+                "m.typing", hs.config.worker.writers.typing,
+            )
 
-        hs.get_distributor().observe("user_left_room", self.user_left_room)
+        # map room IDs to serial numbers
+        self._room_serials = {}
+        # map room IDs to sets of users currently typing
+        self._room_typing = {}
 
-        self._member_typing_until = {}  # clock time we expect to stop
         self._member_last_federation_poke = {}
-
+        self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
-        self._reset()
-
-        # caches which room_ids changed at which serials
-        self._typing_stream_change_cache = StreamChangeCache(
-            "TypingStreamChangeCache", self._latest_room_serial
-        )
 
         self.clock.looping_call(self._handle_timeouts, 5000)
 
     def _reset(self):
-        """
-        Reset the typing handler's data caches.
+        """Reset the typing handler's data caches.
         """
         # map room IDs to serial numbers
         self._room_serials = {}
         # map room IDs to sets of users currently typing
         self._room_typing = {}
 
+        self._member_last_federation_poke = {}
+        self.wheel_timer = WheelTimer(bucket_size=5000)
+
     def _handle_timeouts(self):
         logger.debug("Checking for typing timeouts")
 
@@ -91,32 +93,141 @@ class TypingHandler(object):
         members = set(self.wheel_timer.fetch(now))
 
         for member in members:
-            if not self.is_typing(member):
-                # Nothing to do if they're no longer typing
-                continue
-
-            until = self._member_typing_until.get(member, None)
-            if not until or until <= now:
-                logger.info("Timing out typing for: %s", member.user_id)
-                self._stopped_typing(member)
-                continue
-
-            # Check if we need to resend a keep alive over federation for this
-            # user.
-            if self.hs.is_mine_id(member.user_id):
-                last_fed_poke = self._member_last_federation_poke.get(member, None)
-                if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
-                    run_in_background(self._push_remote, member=member, typing=True)
-
-            # Add a paranoia timer to ensure that we always have a timer for
-            # each person typing.
-            self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
+            self._handle_timeout_for_member(now, member)
+
+    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+        if not self.is_typing(member):
+            # Nothing to do if they're no longer typing
+            return
+
+        # Check if we need to resend a keep alive over federation for this
+        # user.
+        if self.federation and self.is_mine_id(member.user_id):
+            last_fed_poke = self._member_last_federation_poke.get(member, None)
+            if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
+                run_as_background_process(
+                    "typing._push_remote", self._push_remote, member=member, typing=True
+                )
+
+        # Add a paranoia timer to ensure that we always have a timer for
+        # each person typing.
+        self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
 
     def is_typing(self, member):
         return member.user_id in self._room_typing.get(member.room_id, [])
 
-    @defer.inlineCallbacks
-    def started_typing(self, target_user, auth_user, room_id, timeout):
+    async def _push_remote(self, member, typing):
+        if not self.federation:
+            return
+
+        try:
+            users = await self.store.get_users_in_room(member.room_id)
+            self._member_last_federation_poke[member] = self.clock.time_msec()
+
+            now = self.clock.time_msec()
+            self.wheel_timer.insert(
+                now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
+            )
+
+            for domain in {get_domain_from_id(u) for u in users}:
+                if domain != self.server_name:
+                    logger.debug("sending typing update to %s", domain)
+                    self.federation.build_and_send_edu(
+                        destination=domain,
+                        edu_type="m.typing",
+                        content={
+                            "room_id": member.room_id,
+                            "user_id": member.user_id,
+                            "typing": typing,
+                        },
+                        key=member,
+                    )
+        except Exception:
+            logger.exception("Error pushing typing notif to remotes")
+
+    def process_replication_rows(
+        self, token: int, rows: List[TypingStream.TypingStreamRow]
+    ):
+        """Should be called whenever we receive updates for typing stream.
+        """
+
+        if self._latest_room_serial > token:
+            # The master has gone backwards. To prevent inconsistent data, just
+            # clear everything.
+            self._reset()
+
+        # Set the latest serial token to whatever the server gave us.
+        self._latest_room_serial = token
+
+        for row in rows:
+            self._room_serials[row.room_id] = token
+
+            prev_typing = set(self._room_typing.get(row.room_id, []))
+            now_typing = set(row.user_ids)
+            self._room_typing[row.room_id] = row.user_ids
+
+            run_as_background_process(
+                "_handle_change_in_typing",
+                self._handle_change_in_typing,
+                row.room_id,
+                prev_typing,
+                now_typing,
+            )
+
+    async def _handle_change_in_typing(
+        self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
+    ):
+        """Process a change in typing of a room from replication, sending EDUs
+        for any local users.
+        """
+        for user_id in now_typing - prev_typing:
+            if self.is_mine_id(user_id):
+                await self._push_remote(RoomMember(room_id, user_id), True)
+
+        for user_id in prev_typing - now_typing:
+            if self.is_mine_id(user_id):
+                await self._push_remote(RoomMember(room_id, user_id), False)
+
+    def get_current_token(self):
+        return self._latest_room_serial
+
+
+class TypingWriterHandler(FollowerTypingHandler):
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        assert hs.config.worker.writers.typing == hs.get_instance_name()
+
+        self.auth = hs.get_auth()
+        self.notifier = hs.get_notifier()
+
+        self.hs = hs
+
+        hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
+
+        hs.get_distributor().observe("user_left_room", self.user_left_room)
+
+        self._member_typing_until = {}  # clock time we expect to stop
+
+        # caches which room_ids changed at which serials
+        self._typing_stream_change_cache = StreamChangeCache(
+            "TypingStreamChangeCache", self._latest_room_serial
+        )
+
+    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+        super()._handle_timeout_for_member(now, member)
+
+        if not self.is_typing(member):
+            # Nothing to do if they're no longer typing
+            return
+
+        until = self._member_typing_until.get(member, None)
+        if not until or until <= now:
+            logger.info("Timing out typing for: %s", member.user_id)
+            self._stopped_typing(member)
+            return
+
+    async def started_typing(self, target_user, auth_user, room_id, timeout):
         target_user_id = target_user.to_string()
         auth_user_id = auth_user.to_string()
 
@@ -126,7 +237,7 @@ class TypingHandler(object):
         if target_user_id != auth_user_id:
             raise AuthError(400, "Cannot set another user's typing state")
 
-        yield self.auth.check_user_in_room(room_id, target_user_id)
+        await self.auth.check_user_in_room(room_id, target_user_id)
 
         logger.debug("%s has started typing in %s", target_user_id, room_id)
 
@@ -145,8 +256,7 @@ class TypingHandler(object):
 
         self._push_update(member=member, typing=True)
 
-    @defer.inlineCallbacks
-    def stopped_typing(self, target_user, auth_user, room_id):
+    async def stopped_typing(self, target_user, auth_user, room_id):
         target_user_id = target_user.to_string()
         auth_user_id = auth_user.to_string()
 
@@ -156,7 +266,7 @@ class TypingHandler(object):
         if target_user_id != auth_user_id:
             raise AuthError(400, "Cannot set another user's typing state")
 
-        yield self.auth.check_user_in_room(room_id, target_user_id)
+        await self.auth.check_user_in_room(room_id, target_user_id)
 
         logger.debug("%s has stopped typing in %s", target_user_id, room_id)
 
@@ -164,12 +274,11 @@ class TypingHandler(object):
 
         self._stopped_typing(member)
 
-    @defer.inlineCallbacks
     def user_left_room(self, user, room_id):
         user_id = user.to_string()
         if self.is_mine_id(user_id):
             member = RoomMember(room_id=room_id, user_id=user_id)
-            yield self._stopped_typing(member)
+            self._stopped_typing(member)
 
     def _stopped_typing(self, member):
         if member.user_id not in self._room_typing.get(member.room_id, set()):
@@ -184,39 +293,13 @@ class TypingHandler(object):
     def _push_update(self, member, typing):
         if self.hs.is_mine_id(member.user_id):
             # Only send updates for changes to our own users.
-            run_in_background(self._push_remote, member, typing)
-
-        self._push_update_local(member=member, typing=typing)
-
-    @defer.inlineCallbacks
-    def _push_remote(self, member, typing):
-        try:
-            users = yield self.state.get_current_users_in_room(member.room_id)
-            self._member_last_federation_poke[member] = self.clock.time_msec()
-
-            now = self.clock.time_msec()
-            self.wheel_timer.insert(
-                now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
+            run_as_background_process(
+                "typing._push_remote", self._push_remote, member, typing
             )
 
-            for domain in {get_domain_from_id(u) for u in users}:
-                if domain != self.server_name:
-                    logger.debug("sending typing update to %s", domain)
-                    self.federation.build_and_send_edu(
-                        destination=domain,
-                        edu_type="m.typing",
-                        content={
-                            "room_id": member.room_id,
-                            "user_id": member.user_id,
-                            "typing": typing,
-                        },
-                        key=member,
-                    )
-        except Exception:
-            logger.exception("Error pushing typing notif to remotes")
+        self._push_update_local(member=member, typing=typing)
 
-    @defer.inlineCallbacks
-    def _recv_edu(self, origin, content):
+    async def _recv_edu(self, origin, content):
         room_id = content["room_id"]
         user_id = content["user_id"]
 
@@ -231,7 +314,7 @@ class TypingHandler(object):
             )
             return
 
-        users = yield self.state.get_current_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
         domains = {get_domain_from_id(u) for u in users}
 
         if self.server_name in domains:
@@ -259,14 +342,31 @@ class TypingHandler(object):
         )
 
     async def get_all_typing_updates(
-        self, last_id: int, current_id: int, limit: int
-    ) -> List[dict]:
-        """Get up to `limit` typing updates between the given tokens, earliest
-        updates first.
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+        """Get updates for typing replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
         """
 
         if last_id == current_id:
-            return []
+            return [], current_id, False
 
         changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
             last_id
@@ -280,12 +380,25 @@ class TypingHandler(object):
             serial = self._room_serials[room_id]
             if last_id < serial <= current_id:
                 typing = self._room_typing[room_id]
-                rows.append((serial, room_id, list(typing)))
+                rows.append((serial, [room_id, list(typing)]))
         rows.sort()
-        return rows[:limit]
 
-    def get_current_token(self):
-        return self._latest_room_serial
+        limited = False
+        # We, unusually, use a strict limit here as we have all the rows in
+        # memory rather than pulling them out of the database with a `LIMIT ?`
+        # clause.
+        if len(rows) > limit:
+            rows = rows[:limit]
+            current_id = rows[-1][0]
+            limited = True
+
+        return rows, current_id, limited
+
+    def process_replication_rows(
+        self, token: int, rows: List[TypingStream.TypingStreamRow]
+    ):
+        # The writing process should never get updates from replication.
+        raise Exception("Typing writer instance got typing info over replication")
 
 
 class TypingNotificationEventSource(object):
@@ -306,7 +419,7 @@ class TypingNotificationEventSource(object):
             "content": {"user_ids": list(typing)},
         }
 
-    def get_new_events(self, from_key, room_ids, **kwargs):
+    async def get_new_events(self, from_key, room_ids, **kwargs):
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
             handler = self.get_typing_handler()
@@ -320,7 +433,7 @@ class TypingNotificationEventSource(object):
 
                 events.append(self._make_event_for(room_id))
 
-            return defer.succeed((events, handler._latest_room_serial))
+            return (events, handler._latest_room_serial)
 
     def get_current_key(self):
         return self.get_typing_handler()._latest_room_serial