| diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 8ec0119697..dd71d1bc34 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -189,16 +189,34 @@ class ReplicationCommandHandler:
             logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
             raise
 
-        if cmd.token is None or stream_name not in self._streams_connected:
-            # I.e. either this is part of a batch of updates for this stream (in
-            # which case batch until we get an update for the stream with a non
-            # None token) or we're currently connecting so we queue up rows.
-            self._pending_batches.setdefault(stream_name, []).append(row)
-        else:
-            # Check if this is the last of a batch of updates
-            rows = self._pending_batches.pop(stream_name, [])
-            rows.append(row)
-            await self.on_rdata(stream_name, cmd.token, rows)
+        # We linearize here for two reasons:
+        #   1. so we don't try and concurrently handle multiple rows for the
+        #      same stream, and
+        #   2. so we don't race with getting a POSITION command and fetching
+        #      missing RDATA.
+        with await self._position_linearizer.queue(cmd.stream_name):
+            if stream_name not in self._streams_connected:
+                # If the stream isn't marked as connected then we haven't seen a
+                # `POSITION` command yet, and so we may have missed some rows.
+                # Let's drop the row for now, on the assumption we'll receive a
+                # `POSITION` soon and we'll catch up correctly then.
+                logger.warning(
+                    "Discarding RDATA for unconnected stream %s -> %s",
+                    stream_name,
+                    cmd.token,
+                )
+                return
+
+            if cmd.token is None:
+                # I.e. this is part of a batch of updates for this stream (in
+                # which case batch until we get an update for the stream with a non
+                # None token).
+                self._pending_batches.setdefault(stream_name, []).append(row)
+            else:
+                # Check if this is the last of a batch of updates
+                rows = self._pending_batches.pop(stream_name, [])
+                rows.append(row)
+                await self.on_rdata(stream_name, cmd.token, rows)
 
     async def on_rdata(self, stream_name: str, token: int, rows: list):
         """Called to handle a batch of replication data with a given stream token.
@@ -221,12 +239,13 @@ class ReplicationCommandHandler:
         # We protect catching up with a linearizer in case the replication
         # connection reconnects under us.
         with await self._position_linearizer.queue(cmd.stream_name):
-            # We're about to go and catch up with the stream, so mark as connecting
-            # to stop RDATA being handled at the same time by removing stream from
-            # list of connected streams. We also clear any batched up RDATA from
-            # before we got the POSITION.
+            # We're about to go and catch up with the stream, so remove from set
+            # of connected streams.
             self._streams_connected.discard(cmd.stream_name)
-            self._pending_batches.clear()
+
+            # We clear the pending batches for the stream as the fetching of the
+            # missing updates below will fetch all rows in the batch.
+            self._pending_batches.pop(cmd.stream_name, [])
 
             # Find where we previously streamed up to.
             current_token = self._replication_data_handler.get_streams_to_replicate().get(
@@ -239,12 +258,17 @@ class ReplicationCommandHandler:
                 )
                 return
 
-            # Fetch all updates between then and now.
-            limited = True
-            while limited:
-                updates, current_token, limited = await stream.get_updates_since(
-                    current_token, cmd.token
-                )
+            # If the position token matches our current token then we're up to
+            # date and there's nothing to do. Otherwise, fetch all updates
+            # between then and now.
+            missing_updates = cmd.token != current_token
+            while missing_updates:
+                (
+                    updates,
+                    current_token,
+                    missing_updates,
+                ) = await stream.get_updates_since(current_token, cmd.token)
+
                 if updates:
                     await self.on_rdata(
                         cmd.stream_name,
@@ -255,13 +279,6 @@ class ReplicationCommandHandler:
             # We've now caught up to position sent to us, notify handler.
             await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
 
-            # Handle any RDATA that came in while we were catching up.
-            rows = self._pending_batches.pop(cmd.stream_name, [])
-            if rows:
-                await self._replication_data_handler.on_rdata(
-                    cmd.stream_name, rows[-1].token, rows
-                )
-
             self._streams_connected.add(cmd.stream_name)
 
     async def on_SYNC(self, cmd: SyncCommand):
 |