diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 2f5a299141..e32e68e8c4 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,7 +15,18 @@
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
from prometheus_client import Counter
@@ -268,11 +279,14 @@ class ReplicationCommandHandler:
missing_updates,
) = await stream.get_updates_since(current_token, cmd.token)
- if updates:
+ # TODO: add some tests for this
+
+ # Some streams return multiple rows with the same stream IDs,
+ # which need to be processed in batches.
+
+ for token, rows in _batch_updates(updates):
await self.on_rdata(
- cmd.stream_name,
- current_token,
- [stream.parse_row(update[1]) for update in updates],
+ cmd.stream_name, token, [stream.parse_row(row) for row in rows],
)
# We've now caught up to position sent to us, notify handler.
@@ -404,3 +418,52 @@ class ReplicationCommandHandler:
We need to check if the client is interested in the stream or not
"""
self.send_command(RdataCommand(stream_name, token, data))
+
+
+UpdateToken = TypeVar("UpdateToken")
+UpdateRow = TypeVar("UpdateRow")
+
+
+def _batch_updates(
+ updates: Iterable[Tuple[UpdateToken, UpdateRow]]
+) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
+ """Collect stream updates with the same token together
+
+ Given a series of updates returned by Stream.get_updates_since(), collects
+ the updates which share the same stream_id together.
+
+ For example:
+
+ [(1, a), (1, b), (2, c), (3, d), (3, e)]
+
+ becomes:
+
+ [
+ (1, [a, b]),
+ (2, [c]),
+ (3, [d, e]),
+ ]
+ """
+
+ update_iter = iter(updates)
+
+ first_update = next(update_iter, None)
+ if first_update is None:
+ # empty input
+ return
+
+ current_batch_token = first_update[0]
+ current_batch = [first_update[1]]
+
+ for token, row in update_iter:
+ if token != current_batch_token:
+ # different token to the previous row: flush the previous
+ # batch and start anew
+ yield current_batch_token, current_batch
+ current_batch_token = token
+ current_batch = []
+
+ current_batch.append(row)
+
+ # flush the final batch
+ yield current_batch_token, current_batch
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 235f64049c..c61d36a82e 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -126,6 +126,9 @@ class StreamChangeCache(object):
"""
assert type(stream_pos) is int
+ # FIXME: add a sanity check here that we are not overwriting existing
+ # data in self._cache
+
if stream_pos > self._earliest_known_stream_pos:
old_pos = self._entity_to_key.get(entity, None)
if old_pos is not None:
|