diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 8cd47770c1..3bc06d59d5 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -171,6 +171,37 @@ class PositionCommand(Command):
return " ".join((self.stream_name, self.instance_name, str(self.token)))
+class PersistedToCommand(Command):
+ """Sent by a writer to inform others that it has persisted up to the
+ included token.
+
+ The included `token` will *not* have been persisted by the instance.
+
+ Format::
+
+ PERSISTED_TO <stream_name> <instance_name> <token>
+
+ On receipt the client should mark that the given instance has persisted
+ everything up to the given token. Note: this does *not* mean that other
+ instances have also persisted all their rows up to that point.
+ """
+
+ NAME = "PERSISTED_TO"
+
+ def __init__(self, stream_name, instance_name, token):
+ # The three fields of the command, in the order given under `Format`.
+ self.stream_name = stream_name
+ self.instance_name = instance_name
+ self.token = token
+
+ @classmethod
+ def from_line(cls, line):
+ # Split on at most two spaces to obtain the three fields; the token
+ # field is converted to an int before constructing the command.
+ stream_name, instance_name, token = line.split(" ", 2)
+ return cls(stream_name, instance_name, int(token))
+
+ def to_line(self):
+ # Serialise as the space-separated fields, with the token stringified.
+ return " ".join((self.stream_name, self.instance_name, str(self.token)))
+
+
class ErrorCommand(_SimpleCommand):
"""Sent by either side if there was an ERROR. The data is a string describing
the error.
@@ -405,6 +436,7 @@ _COMMANDS = (
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
+ PersistedToCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..08049fe2e0 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -47,6 +47,7 @@ from synapse.replication.tcp.commands import (
ReplicateCommand,
UserIpCommand,
UserSyncCommand,
+ PersistedToCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import (
@@ -387,6 +388,9 @@ class ReplicationCommandHandler:
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
+ def on_PERSISTED_TO(self, conn: AbstractConnection, cmd: PersistedToCommand):
+ # Stub: the command is accepted but nothing is done with it yet.
+ # NOTE(review): presumably invoked for each received PERSISTED_TO
+ # command, like the other `on_<NAME>` handlers - confirm.
+ pass
+
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 687984e7a8..623d7fff3f 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -24,6 +24,7 @@ from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import EventsStream
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -84,6 +85,9 @@ class ReplicationStreamer:
# Set of streams to replicate.
self.streams = self.command_handler.get_streams_to_replicate()
+ if self.streams:
+ self.clock.looping_call(self.on_notifier_poke, 1000.0)
+
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -126,9 +130,7 @@ class ReplicationStreamer:
random.shuffle(all_streams)
for stream in all_streams:
- if stream.last_token == stream.current_token(
- self._instance_name
- ):
+ if not stream.has_updates():
continue
if self._replication_torture_level:
@@ -174,6 +176,11 @@ class ReplicationStreamer:
except Exception:
logger.exception("Failed to replicate")
+ # for command in stream.extra_commands(
+ # sent_updates=bool(updates)
+ # ):
+ # self.command_handler.send_command(command)
+
logger.debug("No more pending updates, breaking poke loop")
finally:
self.pending_updates = False
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 54dccd15a6..f3ea34f886 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -31,6 +31,7 @@ from typing import (
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
+from synapse.replication.tcp.commands import Command
if TYPE_CHECKING:
import synapse.server
@@ -187,6 +188,12 @@ class Stream:
)
return updates, upto_token, limited
+ def has_updates(self) -> bool:
+ # True when the current token for the local instance differs from
+ # `last_token`.
+ return self.current_token(self.local_instance_name) != self.last_token
+
+ def extra_commands(self, sent_updates: bool) -> List[Command]:
+ # Base implementation: no additional commands. `sent_updates` is not
+ # used here.
+ return []
+
def current_token_without_instance(
current_token: Callable[[], int]
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ccc7ca30d8..1aa7ba3da6 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,8 @@ from typing import List, Tuple, Type
import attr
-from ._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.streams._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.commands import Command, PersistedToCommand
"""Handling of the 'events' replication stream
@@ -222,3 +223,18 @@ class EventsStream(Stream):
(typ, data) = row
data = TypeToRow[typ].from_data(data)
return EventsStreamRow(typ, data)
+
+ def has_updates(self) -> bool:
+ # Always reports that there are updates, regardless of token state.
+ # NOTE(review): presumably so the streamer always visits this stream
+ # and gets a chance to emit `extra_commands` - confirm.
+ return True
+
+ def extra_commands(self, sent_updates: bool) -> List[Command]:
+ # Nothing extra to send when updates were sent.
+ if sent_updates:
+ return []
+
+ # Otherwise return a single PERSISTED_TO command for this stream and
+ # the local instance, carrying the position reported by the store's
+ # ID generator. Note this reaches into the store's private
+ # `_stream_id_gen` attribute.
+ return [
+ PersistedToCommand(
+ self.NAME,
+ self.local_instance_name,
+ self._store._stream_id_gen.get_max_persisted_position_for_self(),
+ )
+ ]
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 18def01f50..788158199c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -178,6 +178,8 @@ class PersistEventsStore:
)
persist_event_counter.inc(len(events_and_contexts))
+ logger.debug("Finished persisting 1")
+
if not backfilled:
# backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
@@ -185,6 +187,8 @@ class PersistEventsStore:
events_and_contexts[-1][0].internal_metadata.stream_ordering
)
+ logger.debug("Finished persisting 2")
+
for event, context in events_and_contexts:
if context.app_service:
origin_type = "local"
@@ -198,6 +202,8 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
+ logger.debug("Finished persisting 3")
+
for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
@@ -206,6 +212,9 @@ class PersistEventsStore:
(room_id,), list(latest_event_ids)
)
+ logger.debug("Finished persisting 4")
+ logger.debug("Finished persisting 5")
+
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 4fd7573e26..f09a68e440 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -217,6 +217,7 @@ class MultiWriterIdGenerator:
self._instance_name = instance_name
self._positive = positive
self._writers = writers
+ self._sequence_name = sequence_name
self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
@@ -227,6 +228,8 @@ class MultiWriterIdGenerator:
# return them.
self._current_positions = {} # type: Dict[str, int]
+ self._max_persisted_positions = dict(self._current_positions)
+
# Set of local IDs that we're still processing. The current position
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
@@ -404,6 +407,12 @@ class MultiWriterIdGenerator:
current position if possible.
"""
+ logger.debug(
+ "Mark as finished 1 _current_positions %s: %s",
+ self._sequence_name,
+ self._current_positions,
+ )
+
with self._lock:
self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id)
@@ -439,6 +448,16 @@ class MultiWriterIdGenerator:
if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, new_cur)
+ self._max_persisted_positions[self._instance_name] = max(
+ self._current_positions[self._instance_name],
+ self._max_persisted_positions.get(self._instance_name, 0),
+ )
+
+ logger.debug(
+ "Mark as finished _current_positions %s: %s",
+ self._sequence_name,
+ self._current_positions,
+ )
self._add_persisted_position(next_id)
@@ -454,6 +473,11 @@ class MultiWriterIdGenerator:
"""
with self._lock:
+ logger.debug(
+ "get_current_token_for_writer %s: %s",
+ self._sequence_name,
+ self._current_positions,
+ )
return self._return_factor * self._current_positions.get(instance_name, 0)
def get_positions(self) -> Dict[str, int]:
@@ -478,6 +502,12 @@ class MultiWriterIdGenerator:
new_id, self._current_positions.get(instance_name, 0)
)
+ self._max_persisted_positions[instance_name] = max(
+ new_id,
+ self._current_positions.get(instance_name, 0),
+ self._max_persisted_positions.get(instance_name, 0),
+ )
+
self._add_persisted_position(new_id)
def get_persisted_upto_position(self) -> int:
@@ -492,10 +522,29 @@ class MultiWriterIdGenerator:
with self._lock:
return self._return_factor * self._persisted_upto_position
+ def get_max_persisted_position_for_self(self) -> int:
+     """Return the position this instance reports having persisted up to.
+
+     While this instance still has unfinished IDs the result is its own
+     current position (0 if none is recorded). Otherwise it is the highest
+     current position recorded for any writer (1 if none are recorded).
+     In both cases the value is multiplied by `_return_factor`.
+     """
+     with self._lock:
+         if self._unfinished_ids:
+             # Read the position directly instead of calling
+             # `get_current_token_for_writer`: that method acquires
+             # `self._lock` itself, which would deadlock here if the lock
+             # is a plain (non-reentrant) `threading.Lock`.
+             return self._return_factor * self._current_positions.get(
+                 self._instance_name, 0
+             )
+
+         return self._return_factor * max(
+             self._current_positions.values(), default=1
+         )
+
+ def advance_persisted_to(self, instance_name: str, new_id: int):
+ # Scale the incoming ID by `_return_factor` before comparing it with
+ # the stored positions.
+ # NOTE(review): presumably this converts an external token back to
+ # the internal sign convention - confirm.
+ new_id *= self._return_factor
+
+ with self._lock:
+ # Record the largest of: the new ID, the writer's current position,
+ # and the value already recorded for it (missing entries count as 0),
+ # so the stored value for a writer never decreases.
+ self._max_persisted_positions[instance_name] = max(
+ new_id,
+ self._current_positions.get(instance_name, 0),
+ self._max_persisted_positions.get(instance_name, 0),
+ )
+
def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position.
- This is used to keep the `_current_positions` up to date.
+ This is used to keep the `_persisted_upto_position` up to date.
"""
# We require that the lock is locked by caller
@@ -506,7 +555,7 @@ class MultiWriterIdGenerator:
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
- min_curr = min(self._current_positions.values(), default=0)
+ min_curr = min(self._max_persisted_positions.values(), default=0)
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are
|