diff options
-rw-r--r-- | synapse/replication/tcp/commands.py | 32 | ||||
-rw-r--r-- | synapse/replication/tcp/handler.py | 4 | ||||
-rw-r--r-- | synapse/replication/tcp/resource.py | 13 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 7 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/events.py | 18 | ||||
-rw-r--r-- | synapse/storage/databases/main/events.py | 9 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 53 |
7 files changed, 130 insertions, 6 deletions
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 8cd47770c1..3bc06d59d5 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -171,6 +171,37 @@ class PositionCommand(Command): return " ".join((self.stream_name, self.instance_name, str(self.token))) +class PersistedToCommand(Command): + """Sent by a writer to inform others that it has persisted up to the included + token. + + The included `token` will *not* have been persisted by the instance. + + Format:: + + PERSISTED_TO <stream_name> <instance_name> <token> + + On receipt the client should mark that the given instance has persisted + everything up to the given token. Note: this does *not* mean that other + instances have also persisted all their rows up to that point. + """ + + NAME = "PERSISTED_TO" + + def __init__(self, stream_name, instance_name, token): + self.stream_name = stream_name + self.instance_name = instance_name + self.token = token + + @classmethod + def from_line(cls, line): + stream_name, instance_name, token = line.split(" ", 2) + return cls(stream_name, instance_name, int(token)) + + def to_line(self): + return " ".join((self.stream_name, self.instance_name, str(self.token))) + + class ErrorCommand(_SimpleCommand): """Sent by either side if there was an ERROR. The data is a string describing the error. @@ -405,6 +436,7 @@ _COMMANDS = ( UserIpCommand, RemoteServerUpCommand, ClearUserSyncsCommand, + PersistedToCommand, ) # type: Tuple[Type[Command], ...] # Map of command name to command type. 
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b323841f73..08049fe2e0 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -47,6 +47,7 @@ from synapse.replication.tcp.commands import ( ReplicateCommand, UserIpCommand, UserSyncCommand, + PersistedToCommand, ) from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import ( @@ -387,6 +388,9 @@ class ReplicationCommandHandler: assert self._server_notices_sender is not None await self._server_notices_sender.on_user_ip(cmd.user_id) + def on_PERSISTED_TO(self, conn: AbstractConnection, cmd: PersistedToCommand): + pass + def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 687984e7a8..623d7fff3f 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -24,6 +24,7 @@ from twisted.internet.protocol import Factory from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol +from synapse.replication.tcp.streams import EventsStream from synapse.util.metrics import Measure stream_updates_counter = Counter( @@ -84,6 +85,9 @@ class ReplicationStreamer: # Set of streams to replicate. self.streams = self.command_handler.get_streams_to_replicate() + if self.streams: + self.clock.looping_call(self.on_notifier_poke, 1000.0) + def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the connections if there are. 
@@ -126,9 +130,7 @@ class ReplicationStreamer: random.shuffle(all_streams) for stream in all_streams: - if stream.last_token == stream.current_token( - self._instance_name - ): + if not stream.has_updates(): continue if self._replication_torture_level: @@ -174,6 +176,11 @@ class ReplicationStreamer: except Exception: logger.exception("Failed to replicate") + # for command in stream.extra_commands( + # sent_updates=bool(updates) + # ): + # self.command_handler.send_command(command) + logger.debug("No more pending updates, breaking poke loop") finally: self.pending_updates = False diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 54dccd15a6..f3ea34f886 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -31,6 +31,7 @@ from typing import ( import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +from synapse.replication.tcp.commands import Command if TYPE_CHECKING: import synapse.server @@ -187,6 +188,12 @@ class Stream: ) return updates, upto_token, limited + def has_updates(self) -> bool: + return self.current_token(self.local_instance_name) != self.last_token + + def extra_commands(self, sent_updates: bool) -> List[Command]: + return [] + def current_token_without_instance( current_token: Callable[[], int] diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index ccc7ca30d8..1aa7ba3da6 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,8 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token +from synapse.replication.tcp.streams._base import Stream, StreamUpdateResult, Token +from synapse.replication.tcp.commands import Command, PersistedToCommand """Handling of the 'events' replication stream @@ -222,3 +223,18 @@ class EventsStream(Stream): (typ, data) = row data = 
TypeToRow[typ].from_data(data) return EventsStreamRow(typ, data) + + def has_updates(self) -> bool: + return True + + def extra_commands(self, sent_updates: bool) -> List[Command]: + if sent_updates: + return [] + + return [ + PersistedToCommand( + self.NAME, + self.local_instance_name, + self._store._stream_id_gen.get_max_persisted_position_for_self(), + ) + ] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 18def01f50..788158199c 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -178,6 +178,8 @@ class PersistEventsStore: ) persist_event_counter.inc(len(events_and_contexts)) + logger.debug("Finished persisting 1") + if not backfilled: # backfilled events have negative stream orderings, so we don't # want to set the event_persisted_position to that. @@ -185,6 +187,8 @@ class PersistEventsStore: events_and_contexts[-1][0].internal_metadata.stream_ordering ) + logger.debug("Finished persisting 2") + for event, context in events_and_contexts: if context.app_service: origin_type = "local" @@ -198,6 +202,8 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() + logger.debug("Finished persisting 3") + for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -206,6 +212,9 @@ class PersistEventsStore: (room_id,), list(latest_event_ids) ) + logger.debug("Finished persisting 4") + logger.debug("Finished persisting 5") + async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: """Filter the supplied list of event_ids to get those which are prev_events of existing (non-outlier/rejected) events. 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 4fd7573e26..f09a68e440 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -217,6 +217,7 @@ class MultiWriterIdGenerator: self._instance_name = instance_name self._positive = positive self._writers = writers + self._sequence_name = sequence_name self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. @@ -227,6 +228,8 @@ class MultiWriterIdGenerator: # return them. self._current_positions = {} # type: Dict[str, int] + self._max_persisted_positions = dict(self._current_positions) + # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] @@ -404,6 +407,12 @@ class MultiWriterIdGenerator: current position if possible. """ + logger.debug( + "Mark as finished 1 _current_positions %s: %s", + self._sequence_name, + self._current_positions, + ) + with self._lock: self._unfinished_ids.discard(next_id) self._finished_ids.add(next_id) @@ -439,6 +448,16 @@ class MultiWriterIdGenerator: if new_cur: curr = self._current_positions.get(self._instance_name, 0) self._current_positions[self._instance_name] = max(curr, new_cur) + self._max_persisted_positions[self._instance_name] = max( + self._current_positions[self._instance_name], + self._max_persisted_positions.get(self._instance_name, 0), + ) + + logger.debug( + "Mark as finished _current_positions %s: %s", + self._sequence_name, + self._current_positions, + ) self._add_persisted_position(next_id) @@ -454,6 +473,11 @@ class MultiWriterIdGenerator: """ with self._lock: + logger.debug( + "get_current_token_for_writer %s: %s", + self._sequence_name, + self._current_positions, + ) return self._return_factor * self._current_positions.get(instance_name, 0) def get_positions(self) -> Dict[str, int]: @@ -478,6 +502,12 @@ class 
MultiWriterIdGenerator: new_id, self._current_positions.get(instance_name, 0) ) + self._max_persisted_positions[instance_name] = max( + new_id, + self._current_positions.get(instance_name, 0), + self._max_persisted_positions.get(instance_name, 0), + ) + self._add_persisted_position(new_id) def get_persisted_upto_position(self) -> int: @@ -492,10 +522,29 @@ class MultiWriterIdGenerator: with self._lock: return self._return_factor * self._persisted_upto_position + def get_max_persisted_position_for_self(self) -> int: + with self._lock: + if self._unfinished_ids: + return self.get_current_token_for_writer(self._instance_name) + + return self._return_factor * max( + self._current_positions.values(), default=1 + ) + + def advance_persisted_to(self, instance_name: str, new_id: int): + new_id *= self._return_factor + + with self._lock: + self._max_persisted_positions[instance_name] = max( + new_id, + self._current_positions.get(instance_name, 0), + self._max_persisted_positions.get(instance_name, 0), + ) + def _add_persisted_position(self, new_id: int): """Record that we have persisted a position. - This is used to keep the `_current_positions` up to date. + This is used to keep the `_persisted_upto_position` up to date. """ # We require that the lock is locked by caller @@ -506,7 +555,7 @@ class MultiWriterIdGenerator: # We move the current min position up if the minimum current positions # of all instances is higher (since by definition all positions less # that that have been persisted). - min_curr = min(self._current_positions.values(), default=0) + min_curr = min(self._max_persisted_positions.values(), default=0) self._persisted_upto_position = max(min_curr, self._persisted_upto_position) # We now iterate through the seen positions, discarding those that are |