summary refs log tree commit diff
path: root/synapse/handlers/typing.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/typing.py78
1 files changed, 42 insertions, 36 deletions
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e919a8f9ed..096d199f4c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,13 +15,13 @@
 import logging
 import random
 from collections import namedtuple
-from typing import TYPE_CHECKING, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.streams import TypingStream
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -61,23 +61,23 @@ class FollowerTypingHandler:
 
         if hs.config.worker.writers.typing != hs.get_instance_name():
             hs.get_federation_registry().register_instance_for_edu(
-                "m.typing", hs.config.worker.writers.typing,
+                "m.typing",
+                hs.config.worker.writers.typing,
             )
 
         # map room IDs to serial numbers
-        self._room_serials = {}
+        self._room_serials = {}  # type: Dict[str, int]
         # map room IDs to sets of users currently typing
-        self._room_typing = {}
+        self._room_typing = {}  # type: Dict[str, Set[str]]
 
-        self._member_last_federation_poke = {}
+        self._member_last_federation_poke = {}  # type: Dict[RoomMember, int]
         self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
 
         self.clock.looping_call(self._handle_timeouts, 5000)
 
-    def _reset(self):
-        """Reset the typing handler's data caches.
-        """
+    def _reset(self) -> None:
+        """Reset the typing handler's data caches."""
         # map room IDs to serial numbers
         self._room_serials = {}
         # map room IDs to sets of users currently typing
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
         self._member_last_federation_poke = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
 
-    def _handle_timeouts(self):
+    def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")
 
         now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
         for member in members:
             self._handle_timeout_for_member(now, member)
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         if not self.is_typing(member):
             # Nothing to do if they're no longer typing
             return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
         # each person typing.
         self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
 
-    def is_typing(self, member):
+    def is_typing(self, member: RoomMember) -> bool:
         return member.user_id in self._room_typing.get(member.room_id, [])
 
-    async def _push_remote(self, member, typing):
+    async def _push_remote(self, member: RoomMember, typing: bool) -> None:
         if not self.federation:
             return
 
@@ -148,9 +148,8 @@ class FollowerTypingHandler:
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
-        """Should be called whenever we receive updates for typing stream.
-        """
+    ) -> None:
+        """Should be called whenever we receive updates for typing stream."""
 
         if self._latest_room_serial > token:
             # The master has gone backwards. To prevent inconsistent data, just
@@ -178,7 +177,7 @@ class FollowerTypingHandler:
 
     async def _send_changes_in_typing_to_remotes(
         self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
-    ):
+    ) -> None:
         """Process a change in typing of a room from replication, sending EDUs
         for any local users.
         """
@@ -194,12 +193,12 @@ class FollowerTypingHandler:
             if self.is_mine_id(user_id):
                 await self._push_remote(RoomMember(room_id, user_id), False)
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         return self._latest_room_serial
 
 
 class TypingWriterHandler(FollowerTypingHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,14 +212,15 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
-        self._member_typing_until = {}  # clock time we expect to stop
+        # clock time we expect to stop
+        self._member_typing_until = {}  # type: Dict[RoomMember, int]
 
         # caches which room_ids changed at which serials
         self._typing_stream_change_cache = StreamChangeCache(
             "TypingStreamChangeCache", self._latest_room_serial
         )
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         super()._handle_timeout_for_member(now, member)
 
         if not self.is_typing(member):
@@ -233,7 +233,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             self._stopped_typing(member)
             return
 
-    async def started_typing(self, target_user, requester, room_id, timeout):
+    async def started_typing(
+        self, target_user: UserID, requester: Requester, room_id: str, timeout: int
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -263,11 +265,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         if was_present:
             # No point sending another notification
-            return None
+            return
 
         self._push_update(member=member, typing=True)
 
-    async def stopped_typing(self, target_user, requester, room_id):
+    async def stopped_typing(
+        self, target_user: UserID, requester: Requester, room_id: str
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -290,23 +294,23 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._stopped_typing(member)
 
-    def user_left_room(self, user, room_id):
+    def user_left_room(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
         if self.is_mine_id(user_id):
             member = RoomMember(room_id=room_id, user_id=user_id)
             self._stopped_typing(member)
 
-    def _stopped_typing(self, member):
+    def _stopped_typing(self, member: RoomMember) -> None:
         if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
-            return None
+            return
 
         self._member_typing_until.pop(member, None)
         self._member_last_federation_poke.pop(member, None)
 
         self._push_update(member=member, typing=False)
 
-    def _push_update(self, member, typing):
+    def _push_update(self, member: RoomMember, typing: bool) -> None:
         if self.hs.is_mine_id(member.user_id):
             # Only send updates for changes to our own users.
             run_as_background_process(
@@ -315,7 +319,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._push_update_local(member=member, typing=typing)
 
-    async def _recv_edu(self, origin, content):
+    async def _recv_edu(self, origin: str, content: JsonDict) -> None:
         room_id = content["room_id"]
         user_id = content["user_id"]
 
@@ -340,7 +344,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
             self._push_update_local(member=member, typing=content["typing"])
 
-    def _push_update_local(self, member, typing):
+    def _push_update_local(self, member: RoomMember, typing: bool) -> None:
         room_set = self._room_typing.setdefault(member.room_id, set())
         if typing:
             room_set.add(member.user_id)
@@ -386,7 +390,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
             last_id
-        )
+        )  # type: Optional[Iterable[str]]
 
         if changed_rooms is None:
             changed_rooms = self._room_serials
@@ -412,13 +416,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         # The writing process should never get updates from replication.
         raise Exception("Typing writer instance got typing info over replication")
 
 
 class TypingNotificationEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
         # We can't call get_typing_handler here because there's a cycle:
@@ -427,7 +431,7 @@ class TypingNotificationEventSource:
         #
         self.get_typing_handler = hs.get_typing_handler
 
-    def _make_event_for(self, room_id):
+    def _make_event_for(self, room_id: str) -> JsonDict:
         typing = self.get_typing_handler()._room_typing[room_id]
         return {
             "type": "m.typing",
@@ -462,7 +466,9 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    async def get_new_events(self, from_key, room_ids, **kwargs):
+    async def get_new_events(
+        self, from_key: int, room_ids: Iterable[str], **kwargs
+    ) -> Tuple[List[JsonDict], int]:
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
             handler = self.get_typing_handler()
@@ -478,5 +484,5 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    def get_current_key(self):
+    def get_current_key(self) -> int:
         return self.get_typing_handler()._latest_room_serial