diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 1311b013da..3654f6c03c 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -18,12 +18,15 @@ allowed to be sent by which side.
"""
import abc
import logging
-from typing import Tuple, Type
+from typing import Optional, Tuple, Type, TypeVar
+from synapse.replication.tcp.streams._base import StreamRow
from synapse.util import json_decoder, json_encoder
logger = logging.getLogger(__name__)
+T = TypeVar("T", bound="Command")
+
class Command(metaclass=abc.ABCMeta):
"""The base command class.
@@ -38,7 +41,7 @@ class Command(metaclass=abc.ABCMeta):
@classmethod
@abc.abstractmethod
- def from_line(cls, line):
+ def from_line(cls: Type[T], line: str) -> T:
"""Deserialises a line from the wire into this command. `line` does not
include the command.
"""
@@ -49,21 +52,24 @@ class Command(metaclass=abc.ABCMeta):
prefix.
"""
- def get_logcontext_id(self):
+ def get_logcontext_id(self) -> str:
"""Get a suitable string for the logcontext when processing this command"""
# by default, we just use the command name.
return self.NAME
+SC = TypeVar("SC", bound="_SimpleCommand")
+
+
class _SimpleCommand(Command):
"""An implementation of Command whose argument is just a 'data' string."""
- def __init__(self, data):
+ def __init__(self, data: str):
self.data = data
@classmethod
- def from_line(cls, line):
+ def from_line(cls: Type[SC], line: str) -> SC:
return cls(line)
def to_line(self) -> str:
@@ -109,14 +115,16 @@ class RdataCommand(Command):
NAME = "RDATA"
- def __init__(self, stream_name, instance_name, token, row):
+ def __init__(
+ self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow
+ ):
self.stream_name = stream_name
self.instance_name = instance_name
self.token = token
self.row = row
@classmethod
- def from_line(cls, line):
+ def from_line(cls: Type["RdataCommand"], line: str) -> "RdataCommand":
stream_name, instance_name, token, row_json = line.split(" ", 3)
return cls(
stream_name,
@@ -125,7 +133,7 @@ class RdataCommand(Command):
json_decoder.decode(row_json),
)
- def to_line(self):
+ def to_line(self) -> str:
return " ".join(
(
self.stream_name,
@@ -135,7 +143,7 @@ class RdataCommand(Command):
)
)
- def get_logcontext_id(self):
+ def get_logcontext_id(self) -> str:
return "RDATA-" + self.stream_name
@@ -164,18 +172,20 @@ class PositionCommand(Command):
NAME = "POSITION"
- def __init__(self, stream_name, instance_name, prev_token, new_token):
+ def __init__(
+ self, stream_name: str, instance_name: str, prev_token: int, new_token: int
+ ):
self.stream_name = stream_name
self.instance_name = instance_name
self.prev_token = prev_token
self.new_token = new_token
@classmethod
- def from_line(cls, line):
+ def from_line(cls: Type["PositionCommand"], line: str) -> "PositionCommand":
stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
return cls(stream_name, instance_name, int(prev_token), int(new_token))
- def to_line(self):
+ def to_line(self) -> str:
return " ".join(
(
self.stream_name,
@@ -218,14 +228,14 @@ class ReplicateCommand(Command):
NAME = "REPLICATE"
- def __init__(self):
+ def __init__(self) -> None:
pass
@classmethod
- def from_line(cls, line):
+ def from_line(cls: Type[T], line: str) -> T:
return cls()
- def to_line(self):
+ def to_line(self) -> str:
return ""
@@ -247,14 +257,16 @@ class UserSyncCommand(Command):
NAME = "USER_SYNC"
- def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+ def __init__(
+ self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+ ):
self.instance_id = instance_id
self.user_id = user_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
- def from_line(cls, line):
+ def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
if state not in ("start", "end"):
@@ -262,7 +274,7 @@ class UserSyncCommand(Command):
return cls(instance_id, user_id, state == "start", int(last_sync_ms))
- def to_line(self):
+ def to_line(self) -> str:
return " ".join(
(
self.instance_id,
@@ -286,14 +298,16 @@ class ClearUserSyncsCommand(Command):
NAME = "CLEAR_USER_SYNC"
- def __init__(self, instance_id):
+ def __init__(self, instance_id: str):
self.instance_id = instance_id
@classmethod
- def from_line(cls, line):
+ def from_line(
+ cls: Type["ClearUserSyncsCommand"], line: str
+ ) -> "ClearUserSyncsCommand":
return cls(line)
- def to_line(self):
+ def to_line(self) -> str:
return self.instance_id
@@ -316,7 +330,9 @@ class FederationAckCommand(Command):
self.token = token
@classmethod
- def from_line(cls, line: str) -> "FederationAckCommand":
+ def from_line(
+ cls: Type["FederationAckCommand"], line: str
+ ) -> "FederationAckCommand":
instance_name, token = line.split(" ")
return cls(instance_name, int(token))
@@ -334,7 +350,15 @@ class UserIpCommand(Command):
NAME = "USER_IP"
- def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ def __init__(
+ self,
+ user_id: str,
+ access_token: str,
+ ip: str,
+ user_agent: str,
+ device_id: str,
+ last_seen: int,
+ ):
self.user_id = user_id
self.access_token = access_token
self.ip = ip
@@ -343,14 +367,14 @@ class UserIpCommand(Command):
self.last_seen = last_seen
@classmethod
- def from_line(cls, line):
+ def from_line(cls: Type["UserIpCommand"], line: str) -> "UserIpCommand":
user_id, jsn = line.split(" ", 1)
access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
- def to_line(self):
+ def to_line(self) -> str:
return (
self.user_id
+ " "
|