summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/replication/tcp/commands.py32
-rw-r--r--synapse/replication/tcp/handler.py4
-rw-r--r--synapse/replication/tcp/resource.py13
-rw-r--r--synapse/replication/tcp/streams/_base.py7
-rw-r--r--synapse/replication/tcp/streams/events.py18
-rw-r--r--synapse/storage/databases/main/events.py9
-rw-r--r--synapse/storage/util/id_generators.py53
7 files changed, 130 insertions, 6 deletions
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 8cd47770c1..3bc06d59d5 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -171,6 +171,37 @@ class PositionCommand(Command):
         return " ".join((self.stream_name, self.instance_name, str(self.token)))
 
 
+class PersistedToCommand(Command):
+    """Sent by writers to inform others that it has persisted up to the included
+    token.
+
+    The included `token` will *not* have been persisted by the instance.
+
+    Format::
+
+        PERSISTED_TO <stream_name> <instance_name> <token>
+
+    On receipt the client should mark that the given instances has persisted
+    everything up to the given token. Note: this does *not* mean that other
+    instances have also persisted all their rows up to that point.
+    """
+
+    NAME = "PERSISTED_TO"
+
+    def __init__(self, stream_name, instance_name, token):
+        self.stream_name = stream_name
+        self.instance_name = instance_name
+        self.token = token
+
+    @classmethod
+    def from_line(cls, line):
+        stream_name, instance_name, token = line.split(" ", 2)
+        return cls(stream_name, instance_name, int(token))
+
+    def to_line(self):
+        return " ".join((self.stream_name, self.instance_name, str(self.token)))
+
+
 class ErrorCommand(_SimpleCommand):
     """Sent by either side if there was an ERROR. The data is a string describing
     the error.
@@ -405,6 +436,7 @@ _COMMANDS = (
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
+    PersistedToCommand,
 )  # type: Tuple[Type[Command], ...]
 
 # Map of command name to command type.
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..08049fe2e0 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -47,6 +47,7 @@ from synapse.replication.tcp.commands import (
     ReplicateCommand,
     UserIpCommand,
     UserSyncCommand,
+    PersistedToCommand,
 )
 from synapse.replication.tcp.protocol import AbstractConnection
 from synapse.replication.tcp.streams import (
@@ -387,6 +388,9 @@ class ReplicationCommandHandler:
         assert self._server_notices_sender is not None
         await self._server_notices_sender.on_user_ip(cmd.user_id)
 
+    def on_PERSISTED_TO(self, conn: AbstractConnection, cmd: PersistedToCommand):
+        pass
+
     def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 687984e7a8..623d7fff3f 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -24,6 +24,7 @@ from twisted.internet.protocol import Factory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import EventsStream
 from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
@@ -84,6 +85,9 @@ class ReplicationStreamer:
         # Set of streams to replicate.
         self.streams = self.command_handler.get_streams_to_replicate()
 
+        if self.streams:
+            self.clock.looping_call(self.on_notifier_poke, 1000.0)
+
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
@@ -126,9 +130,7 @@ class ReplicationStreamer:
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.current_token(
-                            self._instance_name
-                        ):
+                        if not stream.has_updates():
                             continue
 
                         if self._replication_torture_level:
@@ -174,6 +176,11 @@ class ReplicationStreamer:
                             except Exception:
                                 logger.exception("Failed to replicate")
 
+                        # for command in stream.extra_commands(
+                        #     sent_updates=bool(updates)
+                        # ):
+                        #     self.command_handler.send_command(command)
+
             logger.debug("No more pending updates, breaking poke loop")
         finally:
             self.pending_updates = False
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 54dccd15a6..f3ea34f886 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -31,6 +31,7 @@ from typing import (
 import attr
 
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
+from synapse.replication.tcp.commands import Command
 
 if TYPE_CHECKING:
     import synapse.server
@@ -187,6 +188,12 @@ class Stream:
         )
         return updates, upto_token, limited
 
+    def has_updates(self) -> bool:
+        return self.current_token(self.local_instance_name) != self.last_token
+
+    def extra_commands(self, sent_updates: bool) -> List[Command]:
+        return []
+
 
 def current_token_without_instance(
     current_token: Callable[[], int]
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ccc7ca30d8..1aa7ba3da6 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,8 @@ from typing import List, Tuple, Type
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.streams._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.commands import Command, PersistedToCommand
 
 """Handling of the 'events' replication stream
 
@@ -222,3 +223,18 @@ class EventsStream(Stream):
         (typ, data) = row
         data = TypeToRow[typ].from_data(data)
         return EventsStreamRow(typ, data)
+
+    def has_updates(self) -> bool:
+        return True
+
+    def extra_commands(self, sent_updates: bool) -> List[Command]:
+        if sent_updates:
+            return []
+
+        return [
+            PersistedToCommand(
+                self.NAME,
+                self.local_instance_name,
+                self._store._stream_id_gen.get_max_persisted_position_for_self(),
+            )
+        ]
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 18def01f50..788158199c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -178,6 +178,8 @@ class PersistEventsStore:
             )
             persist_event_counter.inc(len(events_and_contexts))
 
+            logger.debug("Finished persisting 1")
+
             if not backfilled:
                 # backfilled events have negative stream orderings, so we don't
                 # want to set the event_persisted_position to that.
@@ -185,6 +187,8 @@ class PersistEventsStore:
                     events_and_contexts[-1][0].internal_metadata.stream_ordering
                 )
 
+            logger.debug("Finished persisting 2")
+
             for event, context in events_and_contexts:
                 if context.app_service:
                     origin_type = "local"
@@ -198,6 +202,8 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
+            logger.debug("Finished persisting 3")
+
             for room_id, new_state in current_state_for_room.items():
                 self.store.get_current_state_ids.prefill((room_id,), new_state)
 
@@ -206,6 +212,9 @@ class PersistEventsStore:
                     (room_id,), list(latest_event_ids)
                 )
 
+            logger.debug("Finished persisting 4")
+        logger.debug("Finished persisting 5")
+
     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
         """Filter the supplied list of event_ids to get those which are prev_events of
         existing (non-outlier/rejected) events.
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 4fd7573e26..f09a68e440 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -217,6 +217,7 @@ class MultiWriterIdGenerator:
         self._instance_name = instance_name
         self._positive = positive
         self._writers = writers
+        self._sequence_name = sequence_name
         self._return_factor = 1 if positive else -1
 
         # We lock as some functions may be called from DB threads.
@@ -227,6 +228,8 @@ class MultiWriterIdGenerator:
         # return them.
         self._current_positions = {}  # type: Dict[str, int]
 
+        self._max_persisted_positions = dict(self._current_positions)
+
         # Set of local IDs that we're still processing. The current position
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
@@ -404,6 +407,12 @@ class MultiWriterIdGenerator:
         current position if possible.
         """
 
+        logger.debug(
+            "Mark as finished 1 _current_positions %s: %s",
+            self._sequence_name,
+            self._current_positions,
+        )
+
         with self._lock:
             self._unfinished_ids.discard(next_id)
             self._finished_ids.add(next_id)
@@ -439,6 +448,16 @@ class MultiWriterIdGenerator:
             if new_cur:
                 curr = self._current_positions.get(self._instance_name, 0)
                 self._current_positions[self._instance_name] = max(curr, new_cur)
+                self._max_persisted_positions[self._instance_name] = max(
+                    self._current_positions[self._instance_name],
+                    self._max_persisted_positions.get(self._instance_name, 0),
+                )
+
+            logger.debug(
+                "Mark as finished _current_positions %s: %s",
+                self._sequence_name,
+                self._current_positions,
+            )
 
             self._add_persisted_position(next_id)
 
@@ -454,6 +473,11 @@ class MultiWriterIdGenerator:
         """
 
         with self._lock:
+            logger.debug(
+                "get_current_token_for_writer %s: %s",
+                self._sequence_name,
+                self._current_positions,
+            )
             return self._return_factor * self._current_positions.get(instance_name, 0)
 
     def get_positions(self) -> Dict[str, int]:
@@ -478,6 +502,12 @@ class MultiWriterIdGenerator:
                 new_id, self._current_positions.get(instance_name, 0)
             )
 
+            self._max_persisted_positions[instance_name] = max(
+                new_id,
+                self._current_positions.get(instance_name, 0),
+                self._max_persisted_positions.get(instance_name, 0),
+            )
+
             self._add_persisted_position(new_id)
 
     def get_persisted_upto_position(self) -> int:
@@ -492,10 +522,29 @@ class MultiWriterIdGenerator:
         with self._lock:
             return self._return_factor * self._persisted_upto_position
 
+    def get_max_persisted_position_for_self(self) -> int:
+        with self._lock:
+            if self._unfinished_ids:
+                return self.get_current_token_for_writer(self._instance_name)
+
+            return self._return_factor * max(
+                self._current_positions.values(), default=1
+            )
+
+    def advance_persisted_to(self, instance_name: str, new_id: int):
+        new_id *= self._return_factor
+
+        with self._lock:
+            self._max_persisted_positions[instance_name] = max(
+                new_id,
+                self._current_positions.get(instance_name, 0),
+                self._max_persisted_positions.get(instance_name, 0),
+            )
+
     def _add_persisted_position(self, new_id: int):
         """Record that we have persisted a position.
 
-        This is used to keep the `_current_positions` up to date.
+        This is used to keep the `_persisted_upto_position` up to date.
         """
 
         # We require that the lock is locked by caller
@@ -506,7 +555,7 @@ class MultiWriterIdGenerator:
         # We move the current min position up if the minimum current positions
         # of all instances is higher (since by definition all positions less
         # that that have been persisted).
-        min_curr = min(self._current_positions.values(), default=0)
+        min_curr = min(self._max_persisted_positions.values(), default=0)
         self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
 
         # We now iterate through the seen positions, discarding those that are