diff options
Diffstat (limited to 'synapse/replication/tcp/commands.py')
-rw-r--r-- | synapse/replication/tcp/commands.py | 181 |
1 files changed, 113 insertions, 68 deletions
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 451671412d..f58e384d17 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -17,7 +17,7 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are allowed to be sent by which side. """ - +import abc import logging import platform from typing import Tuple, Type @@ -34,34 +34,29 @@ else: logger = logging.getLogger(__name__) -class Command(object): +class Command(metaclass=abc.ABCMeta): """The base command class. All subclasses must set the NAME variable which equates to the name of the command on the wire. A full command line on the wire is constructed from `NAME + " " + to_line()` - - The default implementation creates a command of form `<NAME> <data>` """ NAME = None # type: str - def __init__(self, data): - self.data = data - @classmethod + @abc.abstractmethod def from_line(cls, line): """Deserialises a line from the wire into this command. `line` does not include the command. """ - return cls(line) - def to_line(self): + @abc.abstractmethod + def to_line(self) -> str: """Serialises the comamnd for the wire. Does not include the command prefix. """ - return self.data def get_logcontext_id(self): """Get a suitable string for the logcontext when processing this command""" @@ -70,7 +65,21 @@ class Command(object): return self.NAME -class ServerCommand(Command): +class _SimpleCommand(Command): + """An implementation of Command whose argument is just a 'data' string.""" + + def __init__(self, data): + self.data = data + + @classmethod + def from_line(cls, line): + return cls(line) + + def to_line(self) -> str: + return self.data + + +class ServerCommand(_SimpleCommand): """Sent by the server on new connection and includes the server_name. Format:: @@ -86,7 +95,7 @@ class RdataCommand(Command): Format:: - RDATA <stream_name> <token> <row_json> + RDATA <stream_name> <instance_name> <token> <row_json> The `<token>` may either be a numeric stream id OR "batch". The latter case is used to support sending multiple updates with the same stream ID. This @@ -96,33 +105,40 @@ class RdataCommand(Command): The client should batch all incoming RDATA with a token of "batch" (per stream_name) until it sees an RDATA with a numeric stream ID. + The `<instance_name>` is the source of the new data (usually "master"). + `<token>` of "batch" maps to the instance variable `token` being None. An example of a batched series of RDATA:: - RDATA presence batch ["@foo:example.com", "online", ...] - RDATA presence batch ["@bar:example.com", "online", ...] - RDATA presence 59 ["@baz:example.com", "online", ...] + RDATA presence master batch ["@foo:example.com", "online", ...] + RDATA presence master batch ["@bar:example.com", "online", ...] + RDATA presence master 59 ["@baz:example.com", "online", ...] """ NAME = "RDATA" - def __init__(self, stream_name, token, row): + def __init__(self, stream_name, instance_name, token, row): self.stream_name = stream_name + self.instance_name = instance_name self.token = token self.row = row @classmethod def from_line(cls, line): - stream_name, token, row_json = line.split(" ", 2) + stream_name, instance_name, token, row_json = line.split(" ", 3) return cls( - stream_name, None if token == "batch" else int(token), json.loads(row_json) + stream_name, + instance_name, + None if token == "batch" else int(token), + json.loads(row_json), ) def to_line(self): return " ".join( ( self.stream_name, + self.instance_name, str(self.token) if self.token is not None else "batch", _json_encoder.encode(self.row), ) @@ -136,26 +152,34 @@ class PositionCommand(Command): """Sent by the server to tell the client the stream postition without needing to send an RDATA. - Sent to the client after all missing updates for a stream have been sent - to the client and they're now up to date. + Format:: + + POSITION <stream_name> <instance_name> <token> + + On receipt of a POSITION command clients should check if they have missed + any updates, and if so then fetch them out of band. + + The `<instance_name>` is the process that sent the command and is the source + of the stream. """ NAME = "POSITION" - def __init__(self, stream_name, token): + def __init__(self, stream_name, instance_name, token): self.stream_name = stream_name + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - return cls(stream_name, int(token)) + stream_name, instance_name, token = line.split(" ", 2) + return cls(stream_name, instance_name, int(token)) def to_line(self): - return " ".join((self.stream_name, str(self.token))) + return " ".join((self.stream_name, self.instance_name, str(self.token))) -class ErrorCommand(Command): +class ErrorCommand(_SimpleCommand): """Sent by either side if there was an ERROR. The data is a string describing the error. """ @@ -163,14 +187,14 @@ class ErrorCommand(Command): NAME = "ERROR" -class PingCommand(Command): +class PingCommand(_SimpleCommand): """Sent by either side as a keep alive. The data is arbitary (often timestamp) """ NAME = "PING" -class NameCommand(Command): +class NameCommand(_SimpleCommand): """Sent by client to inform the server of the client's identity. The data is the name """ @@ -179,76 +203,63 @@ class NameCommand(Command): class ReplicateCommand(Command): - """Sent by the client to subscribe to the stream. + """Sent by the client to subscribe to streams. Format:: - REPLICATE <stream_name> <token> - - Where <token> may be either: - * a numeric stream_id to stream updates from - * "NOW" to stream all subsequent updates. - - The <stream_name> can be "ALL" to subscribe to all known streams, in which - case the <token> must be set to "NOW", i.e.:: - - REPLICATE ALL NOW + REPLICATE """ NAME = "REPLICATE" - def __init__(self, stream_name, token): - self.stream_name = stream_name - self.token = token + def __init__(self): + pass @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - if token in ("NOW", "now"): - token = "NOW" - else: - token = int(token) - return cls(stream_name, token) + return cls() def to_line(self): - return " ".join((self.stream_name, str(self.token))) - - def get_logcontext_id(self): - return "REPLICATE-" + self.stream_name + return "" class UserSyncCommand(Command): """Sent by the client to inform the server that a user has started or - stopped syncing. Used to calculate presence on the master. + stopped syncing on this process. + + This is used by the process handling presence (typically the master) to + calculate who is online and who is not. Includes a timestamp of when the last user sync was. Format:: - USER_SYNC <user_id> <state> <last_sync_ms> + USER_SYNC <instance_id> <user_id> <state> <last_sync_ms> - Where <state> is either "start" or "stop" + Where <state> is either "start" or "end" """ NAME = "USER_SYNC" - def __init__(self, user_id, is_syncing, last_sync_ms): + def __init__(self, instance_id, user_id, is_syncing, last_sync_ms): + self.instance_id = instance_id self.user_id = user_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod def from_line(cls, line): - user_id, state, last_sync_ms = line.split(" ", 2) + instance_id, user_id, state, last_sync_ms = line.split(" ", 3) if state not in ("start", "end"): raise Exception("Invalid USER_SYNC state %r" % (state,)) - return cls(user_id, state == "start", int(last_sync_ms)) + return cls(instance_id, user_id, state == "start", int(last_sync_ms)) def to_line(self): return " ".join( ( + self.instance_id, self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), @@ -256,6 +267,30 @@ class UserSyncCommand(Command): ) +class ClearUserSyncsCommand(Command): + """Sent by the client to inform the server that it should drop all + information about syncing users sent by the client. + + Mainly used when client is about to shut down. + + Format:: + + CLEAR_USER_SYNC <instance_id> + """ + + NAME = "CLEAR_USER_SYNC" + + def __init__(self, instance_id): + self.instance_id = instance_id + + @classmethod + def from_line(cls, line): + return cls(line) + + def to_line(self): + return self.instance_id + + class FederationAckCommand(Command): """Sent by the client when it has processed up to a given point in the federation stream. This allows the master to drop in-memory caches of the @@ -281,14 +316,6 @@ class FederationAckCommand(Command): return str(self.token) -class SyncCommand(Command): - """Used for testing. The client protocol implementation allows waiting - on a SYNC command with a specified data. - """ - - NAME = "SYNC" - - class RemovePusherCommand(Command): """Sent by the client to request the master remove the given pusher. @@ -387,7 +414,7 @@ class UserIpCommand(Command): ) -class RemoteServerUpCommand(Command): +class RemoteServerUpCommand(_SimpleCommand): """Sent when a worker has detected that a remote server is no longer "down" and retry timings should be reset. @@ -411,11 +438,11 @@ _COMMANDS = ( ReplicateCommand, UserSyncCommand, FederationAckCommand, - SyncCommand, RemovePusherCommand, InvalidateCacheCommand, UserIpCommand, RemoteServerUpCommand, + ClearUserSyncsCommand, ) # type: Tuple[Type[Command], ...] # Map of command name to command type. @@ -428,7 +455,6 @@ VALID_SERVER_COMMANDS = ( PositionCommand.NAME, ErrorCommand.NAME, PingCommand.NAME, - SyncCommand.NAME, RemoteServerUpCommand.NAME, ) @@ -438,6 +464,7 @@ VALID_CLIENT_COMMANDS = ( ReplicateCommand.NAME, PingCommand.NAME, UserSyncCommand.NAME, + ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, InvalidateCacheCommand.NAME, @@ -445,3 +472,21 @@ VALID_CLIENT_COMMANDS = ( ErrorCommand.NAME, RemoteServerUpCommand.NAME, ) + + +def parse_command_from_line(line: str) -> Command: + """Parses a command from a received line. + + Line should already be stripped of whitespace and be checked if blank. + """ + + idx = line.find(" ") + if idx >= 0: + cmd_name = line[:idx] + rest_of_line = line[idx + 1 :] + else: + cmd_name = line + rest_of_line = "" + + cmd_cls = COMMAND_MAP[cmd_name] + return cmd_cls.from_line(rest_of_line) |