Diffstat (limited to 'synapse')
-rw-r--r--  synapse/replication/tcp/handler.py            | 73
-rw-r--r--  synapse/replication/tcp/streams/_base.py      |  3
-rw-r--r--  synapse/storage/data_stores/main/push_rule.py | 40
3 files changed, 67 insertions(+), 49 deletions(-)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 8ec0119697..dd71d1bc34 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -189,16 +189,34 @@ class ReplicationCommandHandler:
             logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
             raise
 
-        if cmd.token is None or stream_name not in self._streams_connected:
-            # I.e. either this is part of a batch of updates for this stream (in
-            # which case batch until we get an update for the stream with a non
-            # None token) or we're currently connecting so we queue up rows.
-            self._pending_batches.setdefault(stream_name, []).append(row)
-        else:
-            # Check if this is the last of a batch of updates
-            rows = self._pending_batches.pop(stream_name, [])
-            rows.append(row)
-            await self.on_rdata(stream_name, cmd.token, rows)
+        # We linearize here for two reasons:
+        #   1. so we don't try and concurrently handle multiple rows for the
+        #      same stream, and
+        #   2. so we don't race with getting a POSITION command and fetching
+        #      missing RDATA.
+        with await self._position_linearizer.queue(cmd.stream_name):
+            if stream_name not in self._streams_connected:
+                # If the stream isn't marked as connected then we haven't seen a
+                # `POSITION` command yet, and so we may have missed some rows.
+                # Let's drop the row for now, on the assumption we'll receive a
+                # `POSITION` soon and we'll catch up correctly then.
+                logger.warning(
+                    "Discarding RDATA for unconnected stream %s -> %s",
+                    stream_name,
+                    cmd.token,
+                )
+                return
+
+            if cmd.token is None:
+                # This row is part of a batch of updates for this stream, so
+                # batch it up until we get an update with a non-None token.
+                self._pending_batches.setdefault(stream_name, []).append(row)
+            else:
+                # Check if this is the last of a batch of updates
+                rows = self._pending_batches.pop(stream_name, [])
+                rows.append(row)
+                await self.on_rdata(stream_name, cmd.token, rows)
 
     async def on_rdata(self, stream_name: str, token: int, rows: list):
         """Called to handle a batch of replication data with a given stream token.
@@ -221,12 +239,13 @@ class ReplicationCommandHandler:
         # We protect catching up with a linearizer in case the replication
         # connection reconnects under us.
         with await self._position_linearizer.queue(cmd.stream_name):
-            # We're about to go and catch up with the stream, so mark as connecting
-            # to stop RDATA being handled at the same time by removing stream from
-            # list of connected streams. We also clear any batched up RDATA from
-            # before we got the POSITION.
+            # We're about to go and catch up with the stream, so remove from set
+            # of connected streams.
             self._streams_connected.discard(cmd.stream_name)
-            self._pending_batches.clear()
+
+            # We clear the pending batches for this stream only, since the
+            # catch-up below will refetch every row in the batch anyway.
+            self._pending_batches.pop(cmd.stream_name, [])
 
             # Find where we previously streamed up to.
             current_token = self._replication_data_handler.get_streams_to_replicate().get(
@@ -239,12 +258,17 @@ class ReplicationCommandHandler:
                 )
                 return
 
-            # Fetch all updates between then and now.
-            limited = True
-            while limited:
-                updates, current_token, limited = await stream.get_updates_since(
-                    current_token, cmd.token
-                )
+            # If the position token matches our current token then we're up to
+            # date and there's nothing to do. Otherwise, fetch all updates
+            # between then and now.
+            missing_updates = cmd.token != current_token
+            while missing_updates:
+                (
+                    updates,
+                    current_token,
+                    missing_updates,
+                ) = await stream.get_updates_since(current_token, cmd.token)
+
                 if updates:
                     await self.on_rdata(
                         cmd.stream_name,
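
The loop above only goes back to the stream when the advertised position
differs from the local token, then pages until get_updates_since stops
reporting missing updates. A simplified, runnable sketch of that loop against
a toy update source with the same (updates, upto_token, limited) contract
(the names here are illustrative, not the Synapse internals):

    import asyncio
    from typing import List, Tuple

    # Toy stream with tokens 1..10; each page returns at most `limit` rows
    # plus a flag saying whether more remain, mirroring get_updates_since.
    UPDATES = [(i, (f"row{i}",)) for i in range(1, 11)]

    async def get_updates_since(
        from_token: int, upto_token: int, limit: int = 3
    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        page = [u for u in UPDATES if from_token < u[0] <= upto_token][:limit]
        new_token = page[-1][0] if page else upto_token
        return page, new_token, new_token < upto_token

    async def catch_up(current_token: int, position_token: int) -> int:
        # If the advertised position already matches, there is nothing to fetch.
        missing_updates = position_token != current_token
        while missing_updates:
            updates, current_token, missing_updates = await get_updates_since(
                current_token, position_token
            )
            if updates:
                print("applying tokens", [t for t, _ in updates])
        return current_token

    print(asyncio.run(catch_up(2, 10)))  # applies 3..10 in pages, prints 10
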
@@ -255,13 +279,6 @@ class ReplicationCommandHandler:
             # We've now caught up to position sent to us, notify handler.
             await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
 
-            # Handle any RDATA that came in while we were catching up.
-            rows = self._pending_batches.pop(cmd.stream_name, [])
-            if rows:
-                await self._replication_data_handler.on_rdata(
-                    cmd.stream_name, rows[-1].token, rows
-                )
-
             self._streams_connected.add(cmd.stream_name)
 
     async def on_SYNC(self, cmd: SyncCommand):
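
Taken together, the handler.py hunks mean: live RDATA for a stream is dropped
until a POSITION for it has been seen, rows with a None token are buffered per
stream, and a POSITION linearizes with RDATA handling, drops only that
stream's buffer, catches up, and then re-enables delivery. A minimal sketch of
the combined flow, with hypothetical names standing in for the real classes:

    import asyncio
    from typing import Any, Dict, List, Optional, Set

    class ReplicationSketch:
        """Hypothetical stand-in for the handler's new RDATA/POSITION flow."""

        def __init__(self) -> None:
            self.connected: Set[str] = set()          # streams that saw POSITION
            self.pending: Dict[str, List[Any]] = {}   # buffered rows per stream
            self.lock = asyncio.Lock()                # stand-in for the linearizer

        async def on_rdata(self, stream: str, token: Optional[int], row: Any) -> None:
            async with self.lock:
                if stream not in self.connected:
                    return  # no POSITION yet: drop; catch-up recovers the row
                if token is None:
                    # Mid-batch row: buffer until a row with a token arrives.
                    self.pending.setdefault(stream, []).append(row)
                else:
                    # Last row of the batch: flush everything up to this token.
                    rows = self.pending.pop(stream, [])
                    rows.append(row)
                    print(f"{stream}: {len(rows)} row(s) up to token {token}")

        async def on_position(self, stream: str, token: int) -> None:
            async with self.lock:
                self.connected.discard(stream)   # pause live delivery
                self.pending.pop(stream, None)   # catch-up refetches these rows
                print(f"catching up {stream} to {token}")  # fetch missing updates
                self.connected.add(stream)       # resume live delivery

    async def main() -> None:
        h = ReplicationSketch()
        await h.on_rdata("push_rules", 1, "dropped")  # before any POSITION
        await h.on_position("push_rules", 1)
        await h.on_rdata("push_rules", None, "a")     # buffered
        await h.on_rdata("push_rules", 2, "b")        # flushes both rows

    asyncio.run(main())
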
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index c14dff6c64..f56a0fd4b5 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -168,12 +168,13 @@ def make_http_update_function(
     async def update_function(
         from_token: int, upto_token: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
-        return await client(
+        result = await client(
             stream_name=stream_name,
             from_token=from_token,
             upto_token=upto_token,
             limit=limit,
         )
+        return result["updates"], result["upto_token"], result["limited"]
 
     return update_function
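
The one-line fix above resolves a shape mismatch: the HTTP replication client
resolves to a dict, while update_function's signature promises an
(updates, upto_token, limited) tuple, so the three fields are now unpacked
explicitly. A sketch with a stubbed client standing in for the HTTP endpoint:

    import asyncio
    from typing import Any, Dict, List, Tuple

    async def stub_client(**kwargs: Any) -> Dict[str, Any]:
        # Stand-in for the HTTP replication call, which resolves to a dict.
        return {"updates": [(5, ("row",))], "upto_token": 5, "limited": False}

    def make_http_update_function(stream_name: str, client=stub_client):
        async def update_function(
            from_token: int, upto_token: int, limit: int
        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
            result = await client(
                stream_name=stream_name,
                from_token=from_token,
                upto_token=upto_token,
                limit=limit,
            )
            # Unpack the dict into the tuple shape the signature promises;
            # returning `result` directly (the old code) handed callers a dict.
            return result["updates"], result["upto_token"], result["limited"]

        return update_function

    fn = make_http_update_function("push_rules")
    print(asyncio.run(fn(0, 5, 100)))  # ([(5, ('row',))], 5, False)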
 
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 46f9bda773..b3faafa0a4 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -334,6 +334,26 @@ class PushRulesWorkerStore(
             results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
         return results
 
+    def get_all_push_rule_updates(self, last_id, current_id, limit):
+        """Get all the push rule changes that have happened on the server"""
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_push_rule_updates_txn(txn):
+            sql = (
+                "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
+                " op, priority_class, priority, conditions, actions"
+                " FROM push_rules_stream"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_push_rule_updates", get_all_push_rule_updates_txn
+        )
+
 
 class PushRuleStore(PushRulesWorkerStore):
     @defer.inlineCallbacks
@@ -685,26 +705,6 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_all_push_rule_updates(self, last_id, current_id, limit):
-        """Get all the push rule changes that have happened on the server"""
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_push_rule_updates_txn(txn):
-            sql = (
-                "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
-                " op, priority_class, priority, conditions, actions"
-                " FROM push_rules_stream"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
-
-        return self.db.runInteraction(
-            "get_all_push_rule_updates", get_all_push_rule_updates_txn
-        )
-
     def get_push_rules_stream_token(self):
         """Get the position of the push rules stream.
         Returns a pair of a stream id for the push_rules stream and the
         room stream ordering it corresponds to."""
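
The moved get_all_push_rule_updates pages the push_rules_stream table with a
half-open token window: strictly greater than last_id, up to and including
current_id, ordered by stream_id and capped by LIMIT. A cut-down sqlite3
illustration of that windowing (two columns instead of the real nine):

    import sqlite3

    # Demonstrates the windowing used by get_all_push_rule_updates: rows
    # strictly after last_id, up to and including current_id, in order.
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE push_rules_stream (stream_id INTEGER, rule_id TEXT)")
    conn.executemany(
        "INSERT INTO push_rules_stream VALUES (?, ?)",
        [(i, f"rule{i}") for i in range(1, 8)],
    )

    def get_updates(last_id: int, current_id: int, limit: int):
        if last_id == current_id:
            # Token unchanged: nothing new, skip the query entirely.
            return []
        return conn.execute(
            "SELECT stream_id, rule_id FROM push_rules_stream"
            " WHERE ? < stream_id AND stream_id <= ?"
            " ORDER BY stream_id ASC LIMIT ?",
            (last_id, current_id, limit),
        ).fetchall()

    print(get_updates(2, 7, 3))  # [(3, 'rule3'), (4, 'rule4'), (5, 'rule5')]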