summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/commands.py49
-rw-r--r--synapse/replication/tcp/handler.py4
-rw-r--r--synapse/replication/tcp/protocol.py14
3 files changed, 35 insertions, 32 deletions
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index e4eec643f7..5ec89d0fb8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -17,7 +17,7 @@
 The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
 allowed to be sent by which side.
 """
-
+import abc
 import logging
 import platform
 from typing import Tuple, Type
@@ -34,34 +34,29 @@ else:
 logger = logging.getLogger(__name__)
 
 
-class Command(object):
+class Command(metaclass=abc.ABCMeta):
     """The base command class.
 
     All subclasses must set the NAME variable which equates to the name of the
     command on the wire.
 
     A full command line on the wire is constructed from `NAME + " " + to_line()`
-
-    The default implementation creates a command of form `<NAME> <data>`
     """
 
     NAME = None  # type: str
 
-    def __init__(self, data):
-        self.data = data
-
     @classmethod
+    @abc.abstractmethod
     def from_line(cls, line):
         """Deserialises a line from the wire into this command. `line` does not
         include the command.
         """
-        return cls(line)
 
-    def to_line(self):
+    @abc.abstractmethod
+    def to_line(self) -> str:
         """Serialises the comamnd for the wire. Does not include the command
         prefix.
         """
-        return self.data
 
     def get_logcontext_id(self):
         """Get a suitable string for the logcontext when processing this command"""
@@ -70,7 +65,21 @@ class Command(object):
         return self.NAME
 
 
-class ServerCommand(Command):
+class _SimpleCommand(Command):
+    """An implementation of Command whose argument is just a 'data' string."""
+
+    def __init__(self, data):
+        self.data = data
+
+    @classmethod
+    def from_line(cls, line):
+        return cls(line)
+
+    def to_line(self) -> str:
+        return self.data
+
+
+class ServerCommand(_SimpleCommand):
     """Sent by the server on new connection and includes the server_name.
 
     Format::
@@ -155,7 +164,7 @@ class PositionCommand(Command):
         return " ".join((self.stream_name, str(self.token)))
 
 
-class ErrorCommand(Command):
+class ErrorCommand(_SimpleCommand):
     """Sent by either side if there was an ERROR. The data is a string describing
     the error.
     """
@@ -163,14 +172,14 @@ class ErrorCommand(Command):
     NAME = "ERROR"
 
 
-class PingCommand(Command):
+class PingCommand(_SimpleCommand):
     """Sent by either side as a keep alive. The data is arbitary (often timestamp)
     """
 
     NAME = "PING"
 
 
-class NameCommand(Command):
+class NameCommand(_SimpleCommand):
     """Sent by client to inform the server of the client's identity. The data
     is the name
     """
@@ -289,14 +298,6 @@ class FederationAckCommand(Command):
         return str(self.token)
 
 
-class SyncCommand(Command):
-    """Used for testing. The client protocol implementation allows waiting
-    on a SYNC command with a specified data.
-    """
-
-    NAME = "SYNC"
-
-
 class RemovePusherCommand(Command):
     """Sent by the client to request the master remove the given pusher.
 
@@ -395,7 +396,7 @@ class UserIpCommand(Command):
         )
 
 
-class RemoteServerUpCommand(Command):
+class RemoteServerUpCommand(_SimpleCommand):
     """Sent when a worker has detected that a remote server is no longer
     "down" and retry timings should be reset.
 
@@ -419,7 +420,6 @@ _COMMANDS = (
     ReplicateCommand,
     UserSyncCommand,
     FederationAckCommand,
-    SyncCommand,
     RemovePusherCommand,
     InvalidateCacheCommand,
     UserIpCommand,
@@ -437,7 +437,6 @@ VALID_SERVER_COMMANDS = (
     PositionCommand.NAME,
     ErrorCommand.NAME,
     PingCommand.NAME,
-    SyncCommand.NAME,
     RemoteServerUpCommand.NAME,
 )
 
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index dd71d1bc34..2f5a299141 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -31,7 +31,6 @@ from synapse.replication.tcp.commands import (
     RemoteServerUpCommand,
     RemovePusherCommand,
     ReplicateCommand,
-    SyncCommand,
     UserIpCommand,
     UserSyncCommand,
 )
@@ -281,9 +280,6 @@ class ReplicationCommandHandler:
 
             self._streams_connected.add(cmd.stream_name)
 
-    async def on_SYNC(self, cmd: SyncCommand):
-        pass
-
     async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 9aabb9c586..9276ed2965 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -201,15 +201,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 )
                 self.send_error("ping timeout")
 
-    def lineReceived(self, line):
+    def lineReceived(self, line: bytes):
         """Called when we've received a line
         """
         if line.strip() == "":
             # Ignore blank lines
             return
 
-        line = line.decode("utf-8")
-        cmd_name, rest_of_line = line.split(" ", 1)
+        linestr = line.decode("utf-8")
+
+        # split at the first " ", handling one-word commands
+        idx = linestr.index(" ")
+        if idx >= 0:
+            cmd_name = linestr[:idx]
+            rest_of_line = linestr[idx + 1 :]
+        else:
+            cmd_name = linestr
+            rest_of_line = ""
 
         if cmd_name not in self.VALID_INBOUND_COMMANDS:
             logger.error("[%s] invalid command %s", self.id(), cmd_name)