summary refs log tree commit diff
path: root/synapse/replication/tcp/streams/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r--synapse/replication/tcp/streams/_base.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f45e7a8c89..7e8e64d61c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -33,7 +33,7 @@ import attr
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
 if TYPE_CHECKING:
-    import synapse.server
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -299,20 +299,23 @@ class TypingStream(Stream):
     NAME = "typing"
     ROW_TYPE = TypingStreamRow
 
-    def __init__(self, hs):
-        typing_handler = hs.get_typing_handler()
-
+    def __init__(self, hs: "HomeServer"):
         writer_instance = hs.config.worker.writers.typing
         if writer_instance == hs.get_instance_name():
             # On the writer, query the typing handler
-            update_function = typing_handler.get_all_typing_updates
+            typing_writer_handler = hs.get_typing_writer_handler()
+            update_function = (
+                typing_writer_handler.get_all_typing_updates
+            )  # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+            current_token_function = typing_writer_handler.get_current_token
         else:
             # Query the typing writer process
             update_function = make_http_update_function(hs, self.NAME)
+            current_token_function = hs.get_typing_handler().get_current_token
 
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(typing_handler.get_current_token),
+            current_token_without_instance(current_token_function),
             update_function,
         )
 
@@ -509,7 +512,7 @@ class AccountDataStream(Stream):
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
-    def __init__(self, hs: "synapse.server.HomeServer"):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),