diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e165429cad..2618eb1e53 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -141,21 +141,25 @@ class ReplicationDataHandler:
if row.type != EventsStreamEventRow.TypeId:
continue
assert isinstance(row, EventsStreamRow)
+ assert isinstance(row.data, EventsStreamEventRow)
- event = await self.store.get_event(
- row.data.event_id, allow_rejected=True
- )
- if event.rejected_reason:
+ if row.data.rejected:
continue
extra_users = () # type: Tuple[UserID, ...]
- if event.type == EventTypes.Member:
- extra_users = (UserID.from_string(event.state_key),)
+ if row.data.type == EventTypes.Member and row.data.state_key:
+ extra_users = (UserID.from_string(row.data.state_key),)
max_token = self.store.get_room_max_token()
event_pos = PersistedEventPosition(instance_name, token)
- self.notifier.on_new_room_event(
- event, event_pos, max_token, extra_users
+ self.notifier.on_new_room_event_args(
+ event_pos=event_pos,
+ max_room_stream_token=max_token,
+ extra_users=extra_users,
+ room_id=row.data.room_id,
+ event_type=row.data.type,
+ state_key=row.data.state_key,
+ membership=row.data.membership,
)
# Notify any waiting deferreds. The list is ordered by position so we
@@ -191,6 +195,10 @@ class ReplicationDataHandler:
async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, instance_name, token, [])
+ # We poke the generic "replication" notifier to wake anything up that
+ # may be streaming.
+ self.notifier.notify_replication()
+
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 8cd47770c1..ac532ed588 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -141,15 +141,23 @@ class RdataCommand(Command):
class PositionCommand(Command):
- """Sent by the server to tell the client the stream position without
- needing to send an RDATA.
+ """Sent by an instance to tell others the stream position without needing to
+ send an RDATA.
+
+ Two tokens are sent, the new position and the last position sent by the
+ instance (in an RDATA or other POSITION). The tokens are chosen so that *no*
+ rows were written by the instance between the `prev_token` and `new_token`.
+ (If an instance hasn't sent a position before then the new position can be
+ used for both.)
Format::
- POSITION <stream_name> <instance_name> <token>
+ POSITION <stream_name> <instance_name> <prev_token> <new_token>
- On receipt of a POSITION command clients should check if they have missed
- any updates, and if so then fetch them out of band.
+ On receipt of a POSITION command instances should check if they have missed
+ any updates, and if so then fetch them out of band. Instances can check this
+ by comparing their view of the current token for the sending instance with
+ the included `prev_token`.
The `<instance_name>` is the process that sent the command and is the source
of the stream.
@@ -157,18 +165,26 @@ class PositionCommand(Command):
NAME = "POSITION"
- def __init__(self, stream_name, instance_name, token):
+ def __init__(self, stream_name, instance_name, prev_token, new_token):
self.stream_name = stream_name
self.instance_name = instance_name
- self.token = token
+ self.prev_token = prev_token
+ self.new_token = new_token
@classmethod
def from_line(cls, line):
- stream_name, instance_name, token = line.split(" ", 2)
- return cls(stream_name, instance_name, int(token))
+ stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
+ return cls(stream_name, instance_name, int(prev_token), int(new_token))
def to_line(self):
- return " ".join((self.stream_name, self.instance_name, str(self.token)))
+ return " ".join(
+ (
+ self.stream_name,
+ self.instance_name,
+ str(self.prev_token),
+ str(self.new_token),
+ )
+ )
class ErrorCommand(_SimpleCommand):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..95e5502bf2 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -101,8 +101,9 @@ class ReplicationCommandHandler:
self._streams_to_replicate = [] # type: List[Stream]
for stream in self._streams.values():
- if stream.NAME == CachesStream.NAME:
- # All workers can write to the cache invalidation stream.
+ if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
+ # All workers can write to the cache invalidation stream when
+ # using redis.
self._streams_to_replicate.append(stream)
continue
@@ -251,10 +252,9 @@ class ReplicationCommandHandler:
using TCP.
"""
if hs.config.redis.redis_enabled:
- import txredisapi
-
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
+ lazyConnection,
)
logger.info(
@@ -271,7 +271,8 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
- outbound_redis_connection = txredisapi.lazyConnection(
+ outbound_redis_connection = lazyConnection(
+ reactor=hs.get_reactor(),
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
@@ -313,11 +314,14 @@ class ReplicationCommandHandler:
# We respond with current position of all streams this instance
# replicates.
for stream in self.get_streams_to_replicate():
+ # Note that we use the current token as the prev token here (rather
+ # than stream.last_token), as we can't be sure that there have been
+ # no rows written between last token and the current token (since we
+ # might be racing with the replication sending bg process).
+ current_token = stream.current_token(self._instance_name)
self.send_command(
PositionCommand(
- stream.NAME,
- self._instance_name,
- stream.current_token(self._instance_name),
+ stream.NAME, self._instance_name, current_token, current_token,
)
)
@@ -511,16 +515,16 @@ class ReplicationCommandHandler:
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
- missing_updates = cmd.token != current_token
+ missing_updates = cmd.prev_token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
- cmd.token,
+ cmd.new_token,
)
(updates, current_token, missing_updates) = await stream.get_updates_since(
- cmd.instance_name, current_token, cmd.token
+ cmd.instance_name, current_token, cmd.new_token
)
# TODO: add some tests for this
@@ -536,11 +540,11 @@ class ReplicationCommandHandler:
[stream.parse_row(row) for row in rows],
)
- logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+ logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
- cmd.stream_name, cmd.instance_name, cmd.token
+ cmd.stream_name, cmd.instance_name, cmd.new_token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0b0d204e64..a509e599c2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -51,10 +51,11 @@ import fcntl
import logging
import struct
from inspect import isawaitable
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
+from twisted.internet import task
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
- self.time_we_closed = None # When we requested the connection be closed
+ # When we requested the connection be closed
+ self.time_we_closed = None # type: Optional[int]
- self.received_ping = False # Have we reecived a ping from the other side
+ self.received_ping = False # Have we received a ping from the other side
self.state = ConnectionStates.CONNECTING
@@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
- self._send_ping_loop = None
+ self._send_ping_loop = None # type: Optional[task.LoopingCall]
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index f225e533de..bc6ba709a7 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
import txredisapi
@@ -166,7 +166,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
Args:
cmd (Command)
"""
- run_as_background_process("send-cmd", self._async_send_command, cmd)
+ run_as_background_process(
+ "send-cmd", self._async_send_command, cmd, bg_start_span=False
+ )
async def _async_send_command(self, cmd: Command):
"""Encode a replication command and send it over our outbound connection"""
@@ -228,3 +230,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.password = self.password
return p
+
+
+def lazyConnection(
+ reactor,
+ host: str = "localhost",
+ port: int = 6379,
+ dbid: Optional[int] = None,
+ reconnect: bool = True,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ connectTimeout: Optional[int] = None,
+ replyTimeout: Optional[int] = None,
+ convertNumbers: bool = True,
+) -> txredisapi.RedisProtocol:
+ """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
+ reactor.
+ """
+
+ isLazy = True
+ poolsize = 1
+
+ uuid = "%s:%d" % (host, port)
+ factory = txredisapi.RedisFactory(
+ uuid,
+ dbid,
+ poolsize,
+ isLazy,
+ txredisapi.ConnectionHandler,
+ charset,
+ password,
+ replyTimeout,
+ convertNumbers,
+ )
+ factory.continueTrying = reconnect
+ for x in range(poolsize):
+ reactor.connectTCP(host, port, factory, connectTimeout)
+
+ return factory.handler
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 687984e7a8..1d4ceac0f1 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -23,7 +23,9 @@ from prometheus_client import Counter
from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.commands import PositionCommand
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import EventsStream
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -84,6 +86,23 @@ class ReplicationStreamer:
# Set of streams to replicate.
self.streams = self.command_handler.get_streams_to_replicate()
+ # If we have streams then we must have redis enabled or on master
+ assert (
+ not self.streams
+ or hs.config.redis.redis_enabled
+ or not hs.config.worker.worker_app
+ )
+
+ # If we are replicating an event stream we want to periodically check if
+ # we should send updated POSITIONs. We do this as a looping call rather
+ # explicitly poking when the position advances (without new data to
+ # replicate) to reduce replication traffic (otherwise each writer would
+ # likely send a POSITION for each new event received over replication).
+ #
+ # Note that if the position hasn't advanced then we won't send anything.
+ if any(EventsStream.NAME == s.NAME for s in self.streams):
+ self.clock.looping_call(self.on_notifier_poke, 1000)
+
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -91,13 +110,23 @@ class ReplicationStreamer:
This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed
"""
- if not self.command_handler.connected():
+ if not self.command_handler.connected() or not self.streams:
# Don't bother if nothing is listening. We still need to advance
# the stream tokens otherwise they'll fall behind forever
for stream in self.streams:
stream.discard_updates_and_advance()
return
+ # We check up front to see if anything has actually changed, as we get
+ # poked because of changes that happened on other instances.
+ if all(
+ stream.last_token == stream.current_token(self._instance_name)
+ for stream in self.streams
+ ):
+ return
+
+ # If there are updates then we need to set this even if we're already
+ # looping, as the loop needs to know that he might need to loop again.
self.pending_updates = True
if self.is_looping:
@@ -136,6 +165,8 @@ class ReplicationStreamer:
self._replication_torture_level / 1000.0
)
+ last_token = stream.last_token
+
logger.debug(
"Getting stream: %s: %s -> %s",
stream.NAME,
@@ -159,6 +190,30 @@ class ReplicationStreamer:
)
stream_updates_counter.labels(stream.NAME).inc(len(updates))
+ else:
+ # The token has advanced but there is no data to
+ # send, so we send a `POSITION` to inform other
+ # workers of the updated position.
+ if stream.NAME == EventsStream.NAME:
+ # XXX: We only do this for the EventStream as it
+ # turns out that e.g. account data streams share
+ # their "current token" with each other, meaning
+ # that it is *not* safe to send a POSITION.
+ logger.info(
+ "Sending position: %s -> %s",
+ stream.NAME,
+ current_token,
+ )
+ self.command_handler.send_command(
+ PositionCommand(
+ stream.NAME,
+ self._instance_name,
+ last_token,
+ current_token,
+ )
+ )
+ continue
+
# Some streams return multiple rows with the same stream IDs,
# we need to make sure they get sent out in batches. We do
# this by setting the current token to all but the last of
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 54dccd15a6..61b282ab2d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -240,13 +240,18 @@ class BackfillStream(Stream):
ROW_TYPE = BackfillStreamRow
def __init__(self, hs):
- store = hs.get_datastore()
+ self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_current_backfill_token),
- store.get_all_new_backfill_event_rows,
+ self._current_token,
+ self.store.get_all_new_backfill_event_rows,
)
+ def _current_token(self, instance_name: str) -> int:
+ # The backfill stream over replication operates on *positive* numbers,
+ # which means we need to negate it.
+ return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
+
class PresenceStream(Stream):
PresenceStreamRow = namedtuple(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ccc7ca30d8..86a62b71eb 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,12 +15,15 @@
# limitations under the License.
import heapq
from collections.abc import Iterable
-from typing import List, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
"""Handling of the 'events' replication stream
This stream contains rows of various types. Each row therefore contains a 'type'
@@ -81,12 +84,14 @@ class BaseEventsStreamRow:
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
- event_id = attr.ib() # str
- room_id = attr.ib() # str
- type = attr.ib() # str
- state_key = attr.ib() # str, optional
- redacts = attr.ib() # str, optional
- relates_to = attr.ib() # str, optional
+ event_id = attr.ib(type=str)
+ room_id = attr.ib(type=str)
+ type = attr.ib(type=str)
+ state_key = attr.ib(type=Optional[str])
+ redacts = attr.ib(type=Optional[str])
+ relates_to = attr.ib(type=Optional[str])
+ membership = attr.ib(type=Optional[str])
+ rejected = attr.ib(type=bool)
@attr.s(slots=True, frozen=True)
@@ -113,7 +118,7 @@ class EventsStream(Stream):
NAME = "events"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -155,7 +160,7 @@ class EventsStream(Stream):
# now we fetch up to that many rows from the events table
event_rows = await self._store.get_all_new_forward_event_rows(
- from_token, current_token, target_row_count
+ instance_name, from_token, current_token, target_row_count
) # type: List[Tuple]
# we rely on get_all_new_forward_event_rows strictly honouring the limit, so
@@ -180,7 +185,7 @@ class EventsStream(Stream):
upper_limit,
state_rows_limited,
) = await self._store.get_all_updated_current_state_deltas(
- from_token, upper_limit, target_row_count
+ instance_name, from_token, upper_limit, target_row_count
)
limited = limited or state_rows_limited
@@ -189,7 +194,7 @@ class EventsStream(Stream):
# not to bother with the limit.
ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
- from_token, upper_limit
+ instance_name, from_token, upper_limit
) # type: List[Tuple]
# we now need to turn the raw database rows returned into tuples suitable
|