 changelog.d/7648.bugfix                        |  1
 synapse/app/generic_worker.py                  |  5
 synapse/replication/tcp/commands.py            |  4
 synapse/replication/tcp/handler.py             | 30
 tests/replication/tcp/streams/test_events.py   | 74
 tests/replication/tcp/streams/test_typing.py   | 88
 6 files changed, 175 insertions(+), 27 deletions(-)
diff --git a/changelog.d/7648.bugfix b/changelog.d/7648.bugfix
new file mode 100644
index 0000000000..ff2417bfb6
--- /dev/null
+++ b/changelog.d/7648.bugfix
@@ -0,0 +1 @@
+In worker mode, ensure that replicated data has not already been received.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f3ec2a34ec..53c488d211 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -738,6 +738,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
         except Exception:
             logger.exception("Error processing replication")
 
+    async def on_position(self, stream_name: str, instance_name: str, token: int):
+        await super().on_position(stream_name, instance_name, token)
+        # Also call on_rdata to ensure that stream positions are properly reset.
+        await self.on_rdata(stream_name, instance_name, token, [])
+
     def stop_pusher(self, user_id, app_id, pushkey):
         if not self.notify_pushers:
             return
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index c04f622816..ea5937a20c 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -149,7 +149,7 @@ class RdataCommand(Command):
 
 
 class PositionCommand(Command):
-    """Sent by the server to tell the client the stream postition without
+    """Sent by the server to tell the client the stream position without
     needing to send an RDATA.
 
     Format::
@@ -188,7 +188,7 @@ class ErrorCommand(_SimpleCommand):
 
 
 class PingCommand(_SimpleCommand):
-    """Sent by either side as a keep alive. The data is arbitary (often timestamp)
+    """Sent by either side as a keep alive. The data is arbitrary (often timestamp)
     """
 
     NAME = "PING"
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index cbcf46f3ae..e6a2e2598b 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -112,8 +112,8 @@ class ReplicationCommandHandler:
             "replication_position", clock=self._clock
         )
 
-        # Map of stream to batched updates. See RdataCommand for info on how
-        # batching works.
+        # Map of stream name to batched updates. See RdataCommand for info on
+        # how batching works.
         self._pending_batches = {}  # type: Dict[str, List[Any]]
 
         # The factory used to create connections.
@@ -123,7 +123,8 @@ class ReplicationCommandHandler:
         # outgoing replication commands to.)
         self._connections = []  # type: List[AbstractConnection]
 
-        # For each connection, the incoming streams that are coming from that connection
+        # For each connection, the incoming stream names that are coming from
+        # that connection.
         self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
 
         LaterGauge(
@@ -310,7 +311,28 @@ class ReplicationCommandHandler:
                 # Check if this is the last of a batch of updates
                 rows = self._pending_batches.pop(stream_name, [])
                 rows.append(row)
-                await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+                stream = self._streams.get(stream_name)
+                if not stream:
+                    logger.error("Got RDATA for unknown stream: %s", stream_name)
+                    return
+
+                # Find where we previously streamed up to.
+                current_token = stream.current_token(cmd.instance_name)
+
+                # Discard this data if this token is earlier than the current
+                # position. Note that streams can be reset (in which case you
+                # expect an earlier token), but that must be preceded by a
+                # POSITION command.
+                if cmd.token <= current_token:
+                    logger.debug(
+                        "Discarding RDATA from stream %s at position %s before previous position %s",
+                        stream_name,
+                        cmd.token,
+                        current_token,
+                    )
+                else:
+                    await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
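
The core of the change is the guard added to ReplicationCommandHandler above: incoming RDATA whose token is at or before the stream's current position is logged and dropped instead of being passed to on_rdata, so a worker never re-processes data it has already seen. As a rough standalone sketch of that idea (the handle_rdata helper and _current_tokens map below are hypothetical illustrations, not Synapse's actual classes):

import logging
from typing import Any, Dict, List, Tuple

logger = logging.getLogger(__name__)

# Hypothetical bookkeeping: the last token processed per (stream, instance).
_current_tokens = {}  # type: Dict[Tuple[str, str], int]


def handle_rdata(
    stream_name: str, instance_name: str, token: int, rows: List[Any]
) -> bool:
    """Return True if the rows were processed, False if they were dropped."""
    current_token = _current_tokens.get((stream_name, instance_name), 0)

    # Drop updates at or before the position we have already processed, e.g.
    # data that arrives both before and after a reconnect. A genuine stream
    # reset would be signalled by a POSITION command first.
    if token <= current_token:
        logger.debug(
            "Discarding stale RDATA for %s at %s (already at %s)",
            stream_name,
            token,
            current_token,
        )
        return False

    # ... hand `rows` to the stream's handler here ...
    _current_tokens[(stream_name, instance_name)] = token
    return True


if __name__ == "__main__":
    assert handle_rdata("events", "master", 5, ["row"])          # processed
    assert not handle_rdata("events", "master", 3, ["old row"])  # discarded

The real handler additionally has to cope with batched rows and with streams being legitimately reset via a POSITION command, which is why the actual check sits next to the existing batching logic in on_RDATA.
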
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 51bf0ef4e9..097e1653b4 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -17,6 +17,7 @@ from typing import List, Optional
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
+from synapse.replication.tcp.commands import RdataCommand
 from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
 from synapse.replication.tcp.streams.events import (
     EventsStreamCurrentStateRow,
@@ -66,11 +67,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
         # also one state event
         state_event = self._inject_state_event()
 
-        # tell the notifier to catch up to avoid duplicate rows.
-        # workaround for https://github.com/matrix-org/synapse/issues/7360
-        # FIXME remove this when the above is fixed
-        self.replicate()
-
         # check we're testing what we think we are: no rows should yet have been
         # received
         self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -174,11 +170,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
         # one more bit of state that doesn't get rolled back
         state2 = self._inject_state_event()
 
-        # tell the notifier to catch up to avoid duplicate rows.
-        # workaround for https://github.com/matrix-org/synapse/issues/7360
-        # FIXME remove this when the above is fixed
-        self.replicate()
-
         # check we're testing what we think we are: no rows should yet have been
         # received
         self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -327,11 +318,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
             prev_events = [e.event_id]
             pl_events.append(e)
 
-        # tell the notifier to catch up to avoid duplicate rows.
-        # workaround for https://github.com/matrix-org/synapse/issues/7360
-        # FIXME remove this when the above is fixed
-        self.replicate()
-
         # check we're testing what we think we are: no rows should yet have been
         # received
         self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -378,6 +364,64 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
+    def test_backwards_stream_id(self):
+        """
+        Test that RDATA with a token at or before the current position is discarded.
+        """
+        # disconnect, so that we can stack up some changes
+        self.disconnect()
+
+        # Generate an event. We inject it using inject_event so that it is not
+        # sent out over replication until we call self.replicate().
+        event = self._inject_test_event()
+
+        # check we're testing what we think we are: no rows should yet have been
+        # received
+        self.assertEqual([], self.test_handler.received_rdata_rows)
+
+        # now reconnect to pull the updates
+        self.reconnect()
+        self.replicate()
+
+        # We should have received the expected single row (as well as various
+        # cache invalidation updates which we ignore).
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
+
+        # There should be a single received row.
+        self.assertEqual(len(received_rows), 1)
+
+        stream_name, token, row = received_rows[0]
+        self.assertEqual("events", stream_name)
+        self.assertIsInstance(row, EventsStreamRow)
+        self.assertEqual(row.type, "ev")
+        self.assertIsInstance(row.data, EventsStreamEventRow)
+        self.assertEqual(row.data.event_id, event.event_id)
+
+        # Reset the data.
+        self.test_handler.received_rdata_rows = []
+
+        # Save the current token for later.
+        worker_events_stream = self.worker_hs.get_replication_streams()["events"]
+        prev_token = worker_events_stream.current_token("master")
+
+        # Manually send an old RDATA command, which should get dropped. This
+        # re-uses the row from above, but with an earlier stream token.
+        self.hs.get_tcp_replication().send_command(
+            RdataCommand("events", "master", 1, row)
+        )
+
+        # No updates have been received (because the RDATA was discarded as old).
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
+        self.assertEqual(len(received_rows), 0)
+
+        # Ensure the stream has not gone backwards.
+        current_token = worker_events_stream.current_token("master")
+        self.assertGreaterEqual(current_token, prev_token)
+
     event_count = 0
 
     def _inject_test_event(
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index fd62b26356..5acfb3e53e 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -16,10 +16,15 @@ from mock import Mock
 
 from synapse.handlers.typing import RoomMember
 from synapse.replication.tcp.streams import TypingStream
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from tests.replication._base import BaseStreamTestCase
 
 USER_ID = "@feeling:blue"
+USER_ID_2 = "@da-ba-dee:blue"
+
+ROOM_ID = "!bar:blue"
+ROOM_ID_2 = "!foo:blue"
 
 
 class TypingStreamTestCase(BaseStreamTestCase):
@@ -29,11 +34,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
     def test_typing(self):
         typing = self.hs.get_typing_handler()
 
-        room_id = "!bar:blue"
-
         self.reconnect()
 
-        typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
+        typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
 
         self.reactor.advance(0)
 
@@ -46,7 +49,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
-        self.assertEqual(room_id, row.room_id)
+        self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([USER_ID], row.user_ids)
 
         # Now let's disconnect and insert some data.
@@ -54,7 +57,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
 
         self.test_handler.on_rdata.reset_mock()
 
-        typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
+        typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
 
         self.test_handler.on_rdata.assert_not_called()
 
@@ -73,5 +76,78 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
-        self.assertEqual(room_id, row.room_id)
+        self.assertEqual(ROOM_ID, row.room_id)
+        self.assertEqual([], row.user_ids)
+
+    def test_reset(self):
+        """
+        Test what happens when a typing stream resets.
+
+        This is emulated by jumping the stream ahead, resetting it back to
+        zero, and then reconnecting (which sends the proper position and RDATA).
+        """
+        typing = self.hs.get_typing_handler()
+
+        self.reconnect()
+
+        typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
+
+        self.reactor.advance(0)
+
+        # We should now see an attempt to connect to the master
+        request = self.handle_http_replication_attempt()
+        self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+        self.test_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "typing")
+        self.assertEqual(1, len(rdata_rows))
+        row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
+        self.assertEqual(ROOM_ID, row.room_id)
+        self.assertEqual([USER_ID], row.user_ids)
+
+        # Push the stream forward a bunch so it can be reset.
+        for i in range(100):
+            typing._push_update(
+                member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True
+            )
+        self.reactor.advance(0)
+
+        # Disconnect.
+        self.disconnect()
+
+        # Reset the typing handler
+        self.hs.get_replication_streams()["typing"].last_token = 0
+        self.hs.get_tcp_replication()._streams["typing"].last_token = 0
+        typing._latest_room_serial = 0
+        typing._typing_stream_change_cache = StreamChangeCache(
+            "TypingStreamChangeCache", typing._latest_room_serial
+        )
+        typing._reset()
+
+        # Reconnect.
+        self.reconnect()
+        self.pump(0.1)
+
+        # We should now see an attempt to connect to the master
+        request = self.handle_http_replication_attempt()
+        self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+        # Reset the test code.
+        self.test_handler.on_rdata.reset_mock()
+        self.test_handler.on_rdata.assert_not_called()
+
+        # Push additional data.
+        typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
+        self.reactor.advance(0)
+
+        self.test_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "typing")
+        self.assertEqual(1, len(rdata_rows))
+        row = rdata_rows[0]
+        self.assertEqual(ROOM_ID_2, row.room_id)
         self.assertEqual([], row.user_ids)
+
+        # The token should have been reset.
+        self.assertEqual(token, 1)