diff --git a/changelog.d/7648.bugfix b/changelog.d/7648.bugfix
new file mode 100644
index 0000000000..ff2417bfb6
--- /dev/null
+++ b/changelog.d/7648.bugfix
@@ -0,0 +1 @@
+In working mode, ensure that replicated data has not already been received.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f3ec2a34ec..53c488d211 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -738,6 +738,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
except Exception:
logger.exception("Error processing replication")
+ async def on_position(self, stream_name: str, instance_name: str, token: int):
+ await super().on_position(stream_name, instance_name, token)
+ # Also call on_rdata to ensure that stream positions are properly reset.
+ await self.on_rdata(stream_name, instance_name, token, [])
+
def stop_pusher(self, user_id, app_id, pushkey):
if not self.notify_pushers:
return
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index c04f622816..ea5937a20c 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -149,7 +149,7 @@ class RdataCommand(Command):
class PositionCommand(Command):
- """Sent by the server to tell the client the stream postition without
+ """Sent by the server to tell the client the stream position without
needing to send an RDATA.
Format::
@@ -188,7 +188,7 @@ class ErrorCommand(_SimpleCommand):
class PingCommand(_SimpleCommand):
- """Sent by either side as a keep alive. The data is arbitary (often timestamp)
+ """Sent by either side as a keep alive. The data is arbitrary (often timestamp)
"""
NAME = "PING"
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index cbcf46f3ae..e6a2e2598b 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -112,8 +112,8 @@ class ReplicationCommandHandler:
"replication_position", clock=self._clock
)
- # Map of stream to batched updates. See RdataCommand for info on how
- # batching works.
+ # Map of stream name to batched updates. See RdataCommand for info on
+ # how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]]
# The factory used to create connections.
@@ -123,7 +123,8 @@ class ReplicationCommandHandler:
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
- # For each connection, the incoming streams that are coming from that connection
+ # For each connection, the incoming stream names that are coming from
+ # that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
LaterGauge(
@@ -310,7 +311,28 @@ class ReplicationCommandHandler:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
- await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+ stream = self._streams.get(stream_name)
+ if not stream:
+ logger.error("Got RDATA for unknown stream: %s", stream_name)
+ return
+
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
+
+ # Discard this data if this token is earlier than the current
+ # position. Note that streams can be reset (in which case you
+ # expect an earlier token), but that must be preceded by a
+ # POSITION command.
+ if cmd.token <= current_token:
+ logger.debug(
+ "Discarding RDATA from stream %s at position %s before previous position %s",
+ stream_name,
+ cmd.token,
+ current_token,
+ )
+ else:
+ await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 51bf0ef4e9..097e1653b4 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -17,6 +17,7 @@ from typing import List, Optional
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
+from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
@@ -66,11 +67,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
# also one state event
state_event = self._inject_state_event()
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -174,11 +170,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
# one more bit of state that doesn't get rolled back
state2 = self._inject_state_event()
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -327,11 +318,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
prev_events = [e.event_id]
pl_events.append(e)
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -378,6 +364,64 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
+ def test_backwards_stream_id(self):
+ """
+ Test that RDATA that comes after the current position should be discarded.
+ """
+ # disconnect, so that we can stack up some changes
+ self.disconnect()
+
+ # Generate an events. We inject them using inject_event so that they are
+ # not send out over replication until we call self.replicate().
+ event = self._inject_test_event()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # We should have received the expected single row (as well as various
+ # cache invalidation updates which we ignore).
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+
+ # There should be a single received row.
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows[0]
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, event.event_id)
+
+ # Reset the data.
+ self.test_handler.received_rdata_rows = []
+
+ # Save the current token for later.
+ worker_events_stream = self.worker_hs.get_replication_streams()["events"]
+ prev_token = worker_events_stream.current_token("master")
+
+ # Manually send an old RDATA command, which should get dropped. This
+ # re-uses the row from above, but with an earlier stream token.
+ self.hs.get_tcp_replication().send_command(
+ RdataCommand("events", "master", 1, row)
+ )
+
+ # No updates have been received (because it was discard as old).
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+ self.assertEqual(len(received_rows), 0)
+
+ # Ensure the stream has not gone backwards.
+ current_token = worker_events_stream.current_token("master")
+ self.assertGreaterEqual(current_token, prev_token)
+
event_count = 0
def _inject_test_event(
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index fd62b26356..5acfb3e53e 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -16,10 +16,15 @@ from mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.replication.tcp.streams import TypingStream
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
+USER_ID_2 = "@da-ba-dee:blue"
+
+ROOM_ID = "!bar:blue"
+ROOM_ID_2 = "!foo:blue"
class TypingStreamTestCase(BaseStreamTestCase):
@@ -29,11 +34,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
def test_typing(self):
typing = self.hs.get_typing_handler()
- room_id = "!bar:blue"
-
self.reconnect()
- typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
self.reactor.advance(0)
@@ -46,7 +49,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
- self.assertEqual(room_id, row.room_id)
+ self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
# Now let's disconnect and insert some data.
@@ -54,7 +57,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.test_handler.on_rdata.reset_mock()
- typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
self.test_handler.on_rdata.assert_not_called()
@@ -73,5 +76,78 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
- self.assertEqual(room_id, row.room_id)
+ self.assertEqual(ROOM_ID, row.room_id)
+ self.assertEqual([], row.user_ids)
+
+ def test_reset(self):
+ """
+ Test what happens when a typing stream resets.
+
+ This is emulated by jumping the stream ahead, then reconnecting (which
+ sends the proper position and RDATA).
+ """
+ typing = self.hs.get_typing_handler()
+
+ self.reconnect()
+
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
+
+ self.reactor.advance(0)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0] # type: TypingStream.TypingStreamRow
+ self.assertEqual(ROOM_ID, row.room_id)
+ self.assertEqual([USER_ID], row.user_ids)
+
+ # Push the stream forward a bunch so it can be reset.
+ for i in range(100):
+ typing._push_update(
+ member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True
+ )
+ self.reactor.advance(0)
+
+ # Disconnect.
+ self.disconnect()
+
+ # Reset the typing handler
+ self.hs.get_replication_streams()["typing"].last_token = 0
+ self.hs.get_tcp_replication()._streams["typing"].last_token = 0
+ typing._latest_room_serial = 0
+ typing._typing_stream_change_cache = StreamChangeCache(
+ "TypingStreamChangeCache", typing._latest_room_serial
+ )
+ typing._reset()
+
+ # Reconnect.
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ # Reset the test code.
+ self.test_handler.on_rdata.reset_mock()
+ self.test_handler.on_rdata.assert_not_called()
+
+ # Push additional data.
+ typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
+ self.reactor.advance(0)
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0]
+ self.assertEqual(ROOM_ID_2, row.room_id)
self.assertEqual([], row.user_ids)
+
+ # The token should have been reset.
+ self.assertEqual(token, 1)
|