summary refs log tree commit diff
path: root/synapse/replication/tcp/commands.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/commands.py')
-rw-r--r--synapse/replication/tcp/commands.py74
1 files changed, 49 insertions, 25 deletions
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 1311b013da..3654f6c03c 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -18,12 +18,15 @@ allowed to be sent by which side.
 """
 import abc
 import logging
-from typing import Tuple, Type
+from typing import Optional, Tuple, Type, TypeVar
 
+from synapse.replication.tcp.streams._base import StreamRow
 from synapse.util import json_decoder, json_encoder
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", bound="Command")
+
 
 class Command(metaclass=abc.ABCMeta):
     """The base command class.
@@ -38,7 +41,7 @@ class Command(metaclass=abc.ABCMeta):
 
     @classmethod
     @abc.abstractmethod
-    def from_line(cls, line):
+    def from_line(cls: Type[T], line: str) -> T:
         """Deserialises a line from the wire into this command. `line` does not
         include the command.
         """
@@ -49,21 +52,24 @@ class Command(metaclass=abc.ABCMeta):
         prefix.
         """
 
-    def get_logcontext_id(self):
+    def get_logcontext_id(self) -> str:
         """Get a suitable string for the logcontext when processing this command"""
 
         # by default, we just use the command name.
         return self.NAME
 
 
+SC = TypeVar("SC", bound="_SimpleCommand")
+
+
 class _SimpleCommand(Command):
     """An implementation of Command whose argument is just a 'data' string."""
 
-    def __init__(self, data):
+    def __init__(self, data: str):
         self.data = data
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type[SC], line: str) -> SC:
         return cls(line)
 
     def to_line(self) -> str:
@@ -109,14 +115,16 @@ class RdataCommand(Command):
 
     NAME = "RDATA"
 
-    def __init__(self, stream_name, instance_name, token, row):
+    def __init__(
+        self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow
+    ):
         self.stream_name = stream_name
         self.instance_name = instance_name
         self.token = token
         self.row = row
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["RdataCommand"], line: str) -> "RdataCommand":
         stream_name, instance_name, token, row_json = line.split(" ", 3)
         return cls(
             stream_name,
@@ -125,7 +133,7 @@ class RdataCommand(Command):
             json_decoder.decode(row_json),
         )
 
-    def to_line(self):
+    def to_line(self) -> str:
         return " ".join(
             (
                 self.stream_name,
@@ -135,7 +143,7 @@ class RdataCommand(Command):
             )
         )
 
-    def get_logcontext_id(self):
+    def get_logcontext_id(self) -> str:
         return "RDATA-" + self.stream_name
 
 
@@ -164,18 +172,20 @@ class PositionCommand(Command):
 
     NAME = "POSITION"
 
-    def __init__(self, stream_name, instance_name, prev_token, new_token):
+    def __init__(
+        self, stream_name: str, instance_name: str, prev_token: int, new_token: int
+    ):
         self.stream_name = stream_name
         self.instance_name = instance_name
         self.prev_token = prev_token
         self.new_token = new_token
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["PositionCommand"], line: str) -> "PositionCommand":
         stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
         return cls(stream_name, instance_name, int(prev_token), int(new_token))
 
-    def to_line(self):
+    def to_line(self) -> str:
         return " ".join(
             (
                 self.stream_name,
@@ -218,14 +228,14 @@ class ReplicateCommand(Command):
 
     NAME = "REPLICATE"
 
-    def __init__(self):
+    def __init__(self) -> None:
         pass
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type[T], line: str) -> T:
         return cls()
 
-    def to_line(self):
+    def to_line(self) -> str:
         return ""
 
 
@@ -247,14 +257,16 @@ class UserSyncCommand(Command):
 
     NAME = "USER_SYNC"
 
-    def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+    def __init__(
+        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+    ):
         self.instance_id = instance_id
         self.user_id = user_id
         self.is_syncing = is_syncing
         self.last_sync_ms = last_sync_ms
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
         instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
 
         if state not in ("start", "end"):
@@ -262,7 +274,7 @@ class UserSyncCommand(Command):
 
         return cls(instance_id, user_id, state == "start", int(last_sync_ms))
 
-    def to_line(self):
+    def to_line(self) -> str:
         return " ".join(
             (
                 self.instance_id,
@@ -286,14 +298,16 @@ class ClearUserSyncsCommand(Command):
 
     NAME = "CLEAR_USER_SYNC"
 
-    def __init__(self, instance_id):
+    def __init__(self, instance_id: str):
         self.instance_id = instance_id
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(
+        cls: Type["ClearUserSyncsCommand"], line: str
+    ) -> "ClearUserSyncsCommand":
         return cls(line)
 
-    def to_line(self):
+    def to_line(self) -> str:
         return self.instance_id
 
 
@@ -316,7 +330,9 @@ class FederationAckCommand(Command):
         self.token = token
 
     @classmethod
-    def from_line(cls, line: str) -> "FederationAckCommand":
+    def from_line(
+        cls: Type["FederationAckCommand"], line: str
+    ) -> "FederationAckCommand":
         instance_name, token = line.split(" ")
         return cls(instance_name, int(token))
 
@@ -334,7 +350,15 @@ class UserIpCommand(Command):
 
     NAME = "USER_IP"
 
-    def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+    def __init__(
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: str,
+        last_seen: int,
+    ):
         self.user_id = user_id
         self.access_token = access_token
         self.ip = ip
@@ -343,14 +367,14 @@ class UserIpCommand(Command):
         self.last_seen = last_seen
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["UserIpCommand"], line: str) -> "UserIpCommand":
         user_id, jsn = line.split(" ", 1)
 
         access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
 
         return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
 
-    def to_line(self):
+    def to_line(self) -> str:
         return (
             self.user_id
             + " "