summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7286.misc1
-rw-r--r--synapse/replication/tcp/handler.py73
-rw-r--r--synapse/util/caches/stream_change_cache.py3
3 files changed, 72 insertions, 5 deletions
diff --git a/changelog.d/7286.misc b/changelog.d/7286.misc
new file mode 100644
index 0000000000..676f285377
--- /dev/null
+++ b/changelog.d/7286.misc
@@ -0,0 +1 @@
+Move catchup of replication streams logic to worker.
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 2f5a299141..e32e68e8c4 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,7 +15,18 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+)
 
 from prometheus_client import Counter
 
@@ -268,11 +279,14 @@ class ReplicationCommandHandler:
                     missing_updates,
                 ) = await stream.get_updates_since(current_token, cmd.token)
 
-                if updates:
+                # TODO: add some tests for this
+
+                # Some streams return multiple rows with the same stream IDs,
+                # which need to be processed in batches.
+
+                for token, rows in _batch_updates(updates):
                     await self.on_rdata(
-                        cmd.stream_name,
-                        current_token,
-                        [stream.parse_row(update[1]) for update in updates],
+                        cmd.stream_name, token, [stream.parse_row(row) for row in rows],
                     )
 
             # We've now caught up to position sent to us, notify handler.
@@ -404,3 +418,52 @@ class ReplicationCommandHandler:
         We need to check if the client is interested in the stream or not
         """
         self.send_command(RdataCommand(stream_name, token, data))
+
+
+UpdateToken = TypeVar("UpdateToken")
+UpdateRow = TypeVar("UpdateRow")
+
+
+def _batch_updates(
+    updates: Iterable[Tuple[UpdateToken, UpdateRow]]
+) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
+    """Collect stream updates with the same token together
+
+    Given a series of updates returned by Stream.get_updates_since(), collects
+    the updates which share the same stream_id together.
+
+    For example:
+
+        [(1, a), (1, b), (2, c), (3, d), (3, e)]
+
+    becomes:
+
+        [
+            (1, [a, b]),
+            (2, [c]),
+            (3, [d, e]),
+        ]
+    """
+
+    update_iter = iter(updates)
+
+    first_update = next(update_iter, None)
+    if first_update is None:
+        # empty input
+        return
+
+    current_batch_token = first_update[0]
+    current_batch = [first_update[1]]
+
+    for token, row in update_iter:
+        if token != current_batch_token:
+            # different token to the previous row: flush the previous
+            # batch and start anew
+            yield current_batch_token, current_batch
+            current_batch_token = token
+            current_batch = []
+
+        current_batch.append(row)
+
+    # flush the final batch
+    yield current_batch_token, current_batch
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 235f64049c..c61d36a82e 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -126,6 +126,9 @@ class StreamChangeCache(object):
         """
         assert type(stream_pos) is int
 
+        # FIXME: add a sanity check here that we are not overwriting existing
+        # data in self._cache
+
         if stream_pos > self._earliest_known_stream_pos:
             old_pos = self._entity_to_key.get(entity, None)
             if old_pos is not None: