diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 0ff2a7199f..c04f622816 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -17,50 +17,46 @@
The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
allowed to be sent by which side.
"""
-
+import abc
import logging
import platform
+from typing import Tuple, Type
if platform.python_implementation() == "PyPy":
import json
_json_encoder = json.JSONEncoder()
else:
- import simplejson as json
+ import simplejson as json # type: ignore[no-redef] # noqa: F821
- _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
+ _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
logger = logging.getLogger(__name__)
-class Command(object):
+class Command(metaclass=abc.ABCMeta):
"""The base command class.
All subclasses must set the NAME variable which equates to the name of the
command on the wire.
A full command line on the wire is constructed from `NAME + " " + to_line()`
-
- The default implementation creates a command of form `<NAME> <data>`
"""
- NAME = None
-
- def __init__(self, data):
- self.data = data
+ NAME = None # type: str
@classmethod
+ @abc.abstractmethod
def from_line(cls, line):
"""Deserialises a line from the wire into this command. `line` does not
include the command.
"""
- return cls(line)
- def to_line(self):
+ @abc.abstractmethod
+ def to_line(self) -> str:
"""Serialises the comamnd for the wire. Does not include the command
prefix.
"""
- return self.data
def get_logcontext_id(self):
"""Get a suitable string for the logcontext when processing this command"""
@@ -69,7 +65,21 @@ class Command(object):
return self.NAME
-class ServerCommand(Command):
+class _SimpleCommand(Command):
+ """An implementation of Command whose argument is just a 'data' string."""
+
+ def __init__(self, data):
+ self.data = data
+
+ @classmethod
+ def from_line(cls, line):
+ return cls(line)
+
+ def to_line(self) -> str:
+ return self.data
+
+
+class ServerCommand(_SimpleCommand):
"""Sent by the server on new connection and includes the server_name.
Format::
@@ -85,7 +95,7 @@ class RdataCommand(Command):
Format::
- RDATA <stream_name> <token> <row_json>
+ RDATA <stream_name> <instance_name> <token> <row_json>
The `<token>` may either be a numeric stream id OR "batch". The latter case
is used to support sending multiple updates with the same stream ID. This
@@ -95,33 +105,40 @@ class RdataCommand(Command):
The client should batch all incoming RDATA with a token of "batch" (per
stream_name) until it sees an RDATA with a numeric stream ID.
+ The `<instance_name>` is the source of the new data (usually "master").
+
`<token>` of "batch" maps to the instance variable `token` being None.
An example of a batched series of RDATA::
- RDATA presence batch ["@foo:example.com", "online", ...]
- RDATA presence batch ["@bar:example.com", "online", ...]
- RDATA presence 59 ["@baz:example.com", "online", ...]
+ RDATA presence master batch ["@foo:example.com", "online", ...]
+ RDATA presence master batch ["@bar:example.com", "online", ...]
+ RDATA presence master 59 ["@baz:example.com", "online", ...]
"""
NAME = "RDATA"
- def __init__(self, stream_name, token, row):
+ def __init__(self, stream_name, instance_name, token, row):
self.stream_name = stream_name
+ self.instance_name = instance_name
self.token = token
self.row = row
@classmethod
def from_line(cls, line):
- stream_name, token, row_json = line.split(" ", 2)
+ stream_name, instance_name, token, row_json = line.split(" ", 3)
return cls(
- stream_name, None if token == "batch" else int(token), json.loads(row_json)
+ stream_name,
+ instance_name,
+ None if token == "batch" else int(token),
+ json.loads(row_json),
)
def to_line(self):
return " ".join(
(
self.stream_name,
+ self.instance_name,
str(self.token) if self.token is not None else "batch",
_json_encoder.encode(self.row),
)
@@ -135,26 +152,34 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without
needing to send an RDATA.
- Sent to the client after all missing updates for a stream have been sent
- to the client and they're now up to date.
+ Format::
+
+ POSITION <stream_name> <instance_name> <token>
+
+ On receipt of a POSITION command clients should check if they have missed
+ any updates, and if so then fetch them out of band.
+
+ The `<instance_name>` is the process that sent the command and is the source
+ of the stream.
"""
NAME = "POSITION"
- def __init__(self, stream_name, token):
+ def __init__(self, stream_name, instance_name, token):
self.stream_name = stream_name
+ self.instance_name = instance_name
self.token = token
@classmethod
def from_line(cls, line):
- stream_name, token = line.split(" ", 1)
- return cls(stream_name, int(token))
+ stream_name, instance_name, token = line.split(" ", 2)
+ return cls(stream_name, instance_name, int(token))
def to_line(self):
- return " ".join((self.stream_name, str(self.token)))
+ return " ".join((self.stream_name, self.instance_name, str(self.token)))
-class ErrorCommand(Command):
+class ErrorCommand(_SimpleCommand):
"""Sent by either side if there was an ERROR. The data is a string describing
the error.
"""
@@ -162,14 +187,14 @@ class ErrorCommand(Command):
NAME = "ERROR"
-class PingCommand(Command):
+class PingCommand(_SimpleCommand):
"""Sent by either side as a keep alive. The data is arbitary (often timestamp)
"""
NAME = "PING"
-class NameCommand(Command):
+class NameCommand(_SimpleCommand):
"""Sent by client to inform the server of the client's identity. The data
is the name
"""
@@ -178,76 +203,63 @@ class NameCommand(Command):
class ReplicateCommand(Command):
- """Sent by the client to subscribe to the stream.
+ """Sent by the client to subscribe to streams.
Format::
- REPLICATE <stream_name> <token>
-
- Where <token> may be either:
- * a numeric stream_id to stream updates from
- * "NOW" to stream all subsequent updates.
-
- The <stream_name> can be "ALL" to subscribe to all known streams, in which
- case the <token> must be set to "NOW", i.e.::
-
- REPLICATE ALL NOW
+ REPLICATE
"""
NAME = "REPLICATE"
- def __init__(self, stream_name, token):
- self.stream_name = stream_name
- self.token = token
+ def __init__(self):
+ pass
@classmethod
def from_line(cls, line):
- stream_name, token = line.split(" ", 1)
- if token in ("NOW", "now"):
- token = "NOW"
- else:
- token = int(token)
- return cls(stream_name, token)
+ return cls()
def to_line(self):
- return " ".join((self.stream_name, str(self.token)))
-
- def get_logcontext_id(self):
- return "REPLICATE-" + self.stream_name
+ return ""
class UserSyncCommand(Command):
"""Sent by the client to inform the server that a user has started or
- stopped syncing. Used to calculate presence on the master.
+ stopped syncing on this process.
+
+ This is used by the process handling presence (typically the master) to
+ calculate who is online and who is not.
Includes a timestamp of when the last user sync was.
Format::
- USER_SYNC <user_id> <state> <last_sync_ms>
+ USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
- Where <state> is either "start" or "stop"
+ Where <state> is either "start" or "end"
"""
NAME = "USER_SYNC"
- def __init__(self, user_id, is_syncing, last_sync_ms):
+ def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+ self.instance_id = instance_id
self.user_id = user_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls, line):
- user_id, state, last_sync_ms = line.split(" ", 2)
+ instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
if state not in ("start", "end"):
raise Exception("Invalid USER_SYNC state %r" % (state,))
- return cls(user_id, state == "start", int(last_sync_ms))
+ return cls(instance_id, user_id, state == "start", int(last_sync_ms))
def to_line(self):
return " ".join(
(
+ self.instance_id,
self.user_id,
"start" if self.is_syncing else "end",
str(self.last_sync_ms),
@@ -255,6 +267,30 @@ class UserSyncCommand(Command):
)
+class ClearUserSyncsCommand(Command):
+ """Sent by the client to inform the server that it should drop all
+ information about syncing users sent by the client.
+
+ Mainly used when the client is about to shut down.
+
+ Format::
+
+ CLEAR_USER_SYNC <instance_id>
+ """
+
+ NAME = "CLEAR_USER_SYNC"
+
+ def __init__(self, instance_id):
+ self.instance_id = instance_id
+
+ @classmethod
+ def from_line(cls, line):
+ return cls(line)
+
+ def to_line(self):
+ return self.instance_id
+
+
class FederationAckCommand(Command):
"""Sent by the client when it has processed up to a given point in the
federation stream. This allows the master to drop in-memory caches of the
@@ -280,14 +316,6 @@ class FederationAckCommand(Command):
return str(self.token)
-class SyncCommand(Command):
- """Used for testing. The client protocol implementation allows waiting
- on a SYNC command with a specified data.
- """
-
- NAME = "SYNC"
-
-
class RemovePusherCommand(Command):
"""Sent by the client to request the master remove the given pusher.
@@ -313,37 +341,6 @@ class RemovePusherCommand(Command):
return " ".join((self.app_id, self.push_key, self.user_id))
-class InvalidateCacheCommand(Command):
- """Sent by the client to invalidate an upstream cache.
-
- THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
- NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
- Mainly used to invalidate destination retry timing caches.
-
- Format::
-
- INVALIDATE_CACHE <cache_func> <keys_json>
-
- Where <keys_json> is a json list.
- """
-
- NAME = "INVALIDATE_CACHE"
-
- def __init__(self, cache_func, keys):
- self.cache_func = cache_func
- self.keys = keys
-
- @classmethod
- def from_line(cls, line):
- cache_func, keys_json = line.split(" ", 1)
-
- return cls(cache_func, json.loads(keys_json))
-
- def to_line(self):
- return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
@@ -386,25 +383,38 @@ class UserIpCommand(Command):
)
+class RemoteServerUpCommand(_SimpleCommand):
+ """Sent when a worker has detected that a remote server is no longer
+ "down" and retry timings should be reset.
+
+ If sent from a client the server will relay to all other workers.
+
+ Format::
+
+ REMOTE_SERVER_UP <server>
+ """
+
+ NAME = "REMOTE_SERVER_UP"
+
+
+_COMMANDS = (
+ ServerCommand,
+ RdataCommand,
+ PositionCommand,
+ ErrorCommand,
+ PingCommand,
+ NameCommand,
+ ReplicateCommand,
+ UserSyncCommand,
+ FederationAckCommand,
+ RemovePusherCommand,
+ UserIpCommand,
+ RemoteServerUpCommand,
+ ClearUserSyncsCommand,
+) # type: Tuple[Type[Command], ...]
+
# Map of command name to command type.
-COMMAND_MAP = {
- cmd.NAME: cmd
- for cmd in (
- ServerCommand,
- RdataCommand,
- PositionCommand,
- ErrorCommand,
- PingCommand,
- NameCommand,
- ReplicateCommand,
- UserSyncCommand,
- FederationAckCommand,
- SyncCommand,
- RemovePusherCommand,
- InvalidateCacheCommand,
- UserIpCommand,
- )
-}
+COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
# The commands the server is allowed to send
VALID_SERVER_COMMANDS = (
@@ -413,7 +423,7 @@ VALID_SERVER_COMMANDS = (
PositionCommand.NAME,
ErrorCommand.NAME,
PingCommand.NAME,
- SyncCommand.NAME,
+ RemoteServerUpCommand.NAME,
)
# The commands the client is allowed to send
@@ -422,9 +432,28 @@ VALID_CLIENT_COMMANDS = (
ReplicateCommand.NAME,
PingCommand.NAME,
UserSyncCommand.NAME,
+ ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
- InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
+ RemoteServerUpCommand.NAME,
)
+
+
+def parse_command_from_line(line: str) -> Command:
+ """Parses a command from a received line.
+
+ The line should already be stripped of whitespace and known to be non-blank.
+ """
+
+ idx = line.find(" ")
+ if idx >= 0:
+ cmd_name = line[:idx]
+ rest_of_line = line[idx + 1 :]
+ else:
+ cmd_name = line
+ rest_of_line = ""
+
+ cmd_cls = COMMAND_MAP[cmd_name]
+ return cmd_cls.from_line(rest_of_line)
|