summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-03-17 11:30:21 -0400
committerGitHub <noreply@github.com>2021-03-17 11:30:21 -0400
commitcc324d53fe531d002aca28a9d8e5b85768cdef23 (patch)
tree449f90ba69043a2d732531ed3937183a7901cb6a /synapse
parentonly save remote cross-signing keys if they're different from the current one... (diff)
downloadsynapse-cc324d53fe531d002aca28a9d8e5b85768cdef23.tar.xz
Fix up types for the typing handler. (#9638)
By splitting this to two separate methods the callers know
what methods they can expect on the handler.
Diffstat (limited to '')
-rw-r--r--synapse/replication/tcp/streams/_base.py17
-rw-r--r--synapse/rest/client/v1/room.py15
-rw-r--r--synapse/server.py11
3 files changed, 29 insertions, 14 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f45e7a8c89..7e8e64d61c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -33,7 +33,7 @@ import attr
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
 if TYPE_CHECKING:
-    import synapse.server
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -299,20 +299,23 @@ class TypingStream(Stream):
     NAME = "typing"
     ROW_TYPE = TypingStreamRow
 
-    def __init__(self, hs):
-        typing_handler = hs.get_typing_handler()
-
+    def __init__(self, hs: "HomeServer"):
         writer_instance = hs.config.worker.writers.typing
         if writer_instance == hs.get_instance_name():
             # On the writer, query the typing handler
-            update_function = typing_handler.get_all_typing_updates
+            typing_writer_handler = hs.get_typing_writer_handler()
+            update_function = (
+                typing_writer_handler.get_all_typing_updates
+            )  # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+            current_token_function = typing_writer_handler.get_current_token
         else:
             # Query the typing writer process
             update_function = make_http_update_function(hs, self.NAME)
+            current_token_function = hs.get_typing_handler().get_current_token
 
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(typing_handler.get_current_token),
+            current_token_without_instance(current_token_function),
             update_function,
         )
 
@@ -509,7 +512,7 @@ class AccountDataStream(Stream):
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
-    def __init__(self, hs: "synapse.server.HomeServer"):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 5884daea6d..e7a8207eb1 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -49,7 +49,7 @@ from synapse.util import json_decoder
 from synapse.util.stringutils import parse_and_validate_server_name, random_string
 
 if TYPE_CHECKING:
-    import synapse.server
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -846,10 +846,10 @@ class RoomTypingRestServlet(RestServlet):
         "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
+        self.hs = hs
         self.presence_handler = hs.get_presence_handler()
-        self.typing_handler = hs.get_typing_handler()
         self.auth = hs.get_auth()
 
         # If we're not on the typing writer instance we should scream if we get
@@ -874,16 +874,19 @@ class RoomTypingRestServlet(RestServlet):
         # Limit timeout to stop people from setting silly typing timeouts.
         timeout = min(content.get("timeout", 30000), 120000)
 
+        # Defer getting the typing handler since it will raise on workers.
+        typing_handler = self.hs.get_typing_writer_handler()
+
         try:
             if content["typing"]:
-                await self.typing_handler.started_typing(
+                await typing_handler.started_typing(
                     target_user=target_user,
                     requester=requester,
                     room_id=room_id,
                     timeout=timeout,
                 )
             else:
-                await self.typing_handler.stopped_typing(
+                await typing_handler.stopped_typing(
                     target_user=target_user, requester=requester, room_id=room_id
                 )
         except ShadowBanError:
@@ -901,7 +904,7 @@ class RoomAliasListServlet(RestServlet):
         ),
     ]
 
-    def __init__(self, hs: "synapse.server.HomeServer"):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.directory_handler = hs.get_directory_handler()
diff --git a/synapse/server.py b/synapse/server.py
index dd4ee7dd3c..d11d08c573 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -417,10 +417,19 @@ class HomeServer(metaclass=abc.ABCMeta):
         return PresenceHandler(self)
 
     @cache_in_self
-    def get_typing_handler(self):
+    def get_typing_writer_handler(self) -> TypingWriterHandler:
         if self.config.worker.writers.typing == self.get_instance_name():
             return TypingWriterHandler(self)
         else:
+            raise Exception("Workers cannot write typing")
+
+    @cache_in_self
+    def get_typing_handler(self) -> FollowerTypingHandler:
+        if self.config.worker.writers.typing == self.get_instance_name():
+            # Use get_typing_writer_handler to ensure that we use the same
+            # cached version.
+            return self.get_typing_writer_handler()
+        else:
             return FollowerTypingHandler(self)
 
     @cache_in_self