summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/generic_worker.py5
-rw-r--r--synapse/replication/tcp/commands.py4
-rw-r--r--synapse/replication/tcp/handler.py30
3 files changed, 33 insertions, 6 deletions
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f3ec2a34ec..53c488d211 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -738,6 +738,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
         except Exception:
             logger.exception("Error processing replication")
 
+    async def on_position(self, stream_name: str, instance_name: str, token: int):
+        await super().on_position(stream_name, instance_name, token)
+        # Also call on_rdata to ensure that stream positions are properly reset.
+        await self.on_rdata(stream_name, instance_name, token, [])
+
     def stop_pusher(self, user_id, app_id, pushkey):
         if not self.notify_pushers:
             return
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index c04f622816..ea5937a20c 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -149,7 +149,7 @@ class RdataCommand(Command):
 
 
 class PositionCommand(Command):
-    """Sent by the server to tell the client the stream postition without
+    """Sent by the server to tell the client the stream position without
     needing to send an RDATA.
 
     Format::
@@ -188,7 +188,7 @@ class ErrorCommand(_SimpleCommand):
 
 
 class PingCommand(_SimpleCommand):
-    """Sent by either side as a keep alive. The data is arbitary (often timestamp)
+    """Sent by either side as a keep alive. The data is arbitrary (often timestamp)
     """
 
     NAME = "PING"
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index cbcf46f3ae..e6a2e2598b 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -112,8 +112,8 @@ class ReplicationCommandHandler:
             "replication_position", clock=self._clock
         )
 
-        # Map of stream to batched updates. See RdataCommand for info on how
-        # batching works.
+        # Map of stream name to batched updates. See RdataCommand for info on
+        # how batching works.
         self._pending_batches = {}  # type: Dict[str, List[Any]]
 
         # The factory used to create connections.
@@ -123,7 +123,8 @@ class ReplicationCommandHandler:
         # outgoing replication commands to.)
         self._connections = []  # type: List[AbstractConnection]
 
-        # For each connection, the incoming streams that are coming from that connection
+        # For each connection, the incoming stream names that are coming from
+        # that connection.
         self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
 
         LaterGauge(
@@ -310,7 +311,28 @@ class ReplicationCommandHandler:
                 # Check if this is the last of a batch of updates
                 rows = self._pending_batches.pop(stream_name, [])
                 rows.append(row)
-                await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+                stream = self._streams.get(stream_name)
+                if not stream:
+                    logger.error("Got RDATA for unknown stream: %s", stream_name)
+                    return
+
+                # Find where we previously streamed up to.
+                current_token = stream.current_token(cmd.instance_name)
+
+                # Discard this data if this token is earlier than the current
+                # position. Note that streams can be reset (in which case you
+                # expect an earlier token), but that must be preceded by a
+                # POSITION command.
+                if cmd.token <= current_token:
+                    logger.debug(
+                        "Discarding RDATA from stream %s at position %s before previous position %s",
+                        stream_name,
+                        cmd.token,
+                        current_token,
+                    )
+                else:
+                    await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list