diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 02ab5b66ea..7e7ad0f798 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
+ self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
- self.client_name, self.server_name, self._clock, self.handler
+ self.hs, self.client_name, self.server_name, self._clock, self.handler,
)
def clientConnectionLost(self, connector, reason):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index d7ef2398fa..649312f022 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -82,7 +82,8 @@ from synapse.replication.tcp.commands import (
SyncCommand,
UserSyncCommand,
)
-from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.server import HomeServer
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -414,9 +415,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
# The streams the client has subscribed to and is up to date with
self.replication_streams = set() # type: Set[str]
- # The streams the client is currently subscribing to.
- self.connecting_streams = set() # type: Set[str]
-
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
@@ -482,67 +480,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
are queued and sent once we've sent down any missed updates.
"""
self.replication_streams.discard(stream_name)
- self.connecting_streams.add(stream_name)
try:
- limited = True
- while limited:
- # Get missing updates
- (
- updates,
- current_token,
- limited,
- ) = await self.streamer.get_stream_updates(stream_name, token)
-
- # Send all the missing updates
- for update in updates:
- token, row = update[0], update[1]
- self.send_command(RdataCommand(stream_name, token, row))
+ # Get current stream position.
+ current_token = self.streamer.get_stream_token(stream_name)
# We send a POSITION command to ensure that they have an up to
# date token (especially useful if we didn't send any updates
# above)
self.send_command(PositionCommand(stream_name, current_token))
- # Now we can send any updates that came in while we were subscribing
- pending_rdata = self.pending_rdata.pop(stream_name, [])
- updates = []
- for token, update in pending_rdata:
- # If the token is null, it is part of a batch update. Batches
- # are multiple updates that share a single token. To denote
- # this, the token is set to None for all tokens in the batch
- # except for the last. If we find a None token, we keep looking
- # through tokens until we find one that is not None and then
- # process all previous updates in the batch as if they had the
- # final token.
- if token is None:
- # Store this update as part of a batch
- updates.append(update)
- continue
-
- if token <= current_token:
- # This update or batch of updates is older than
- # current_token, dismiss it
- updates = []
- continue
-
- updates.append(update)
-
- # Send all updates that are part of this batch with the
- # found token
- for update in updates:
- self.send_command(RdataCommand(stream_name, token, update))
-
- # Clear stored updates
- updates = []
-
# They're now fully subscribed
self.replication_streams.add(stream_name)
except Exception as e:
logger.exception("[%s] Failed to handle REPLICATE command", self.id())
self.send_error("failed to handle replicate: %r", e)
- finally:
- self.connecting_streams.discard(stream_name)
def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
@@ -552,10 +504,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name in self.replication_streams:
# The client is subscribed to the stream
self.send_command(RdataCommand(stream_name, token, data))
- elif stream_name in self.connecting_streams:
- # The client is being subscribed to the stream
- logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
- self.pending_rdata.setdefault(stream_name, []).append((token, data))
else:
# The client isn't subscribed
logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
@@ -642,6 +590,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__(
self,
+ hs: HomeServer,
client_name: str,
server_name: str,
clock: Clock,
@@ -653,6 +602,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.handler = handler
+ self.streams = {
+ stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+ } # type: Dict[str, Stream]
+
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
@@ -660,7 +613,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
- self.pending_batches = {} # type: Dict[str, Any]
+ self.pending_batches = {} # type: Dict[str, List[Any]]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
@@ -701,7 +654,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
)
raise
- if cmd.token is None:
+ if cmd.token is None or stream_name in self.streams_connecting:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
@@ -711,14 +664,46 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
- async def on_POSITION(self, cmd):
+ async def on_POSITION(self, cmd: PositionCommand):
+ stream = self.streams.get(cmd.stream_name)
+ if not stream:
+ logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+ return
+
+ # Find where we previously streamed up to.
+ current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
+ if current_token is None:
+ logger.warning(
+ "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
+ )
+ return
+
+ # Fetch all updates between then and now.
+ limited = True
+ while limited:
+ updates, current_token, limited = await stream.get_updates_since(
+ current_token, cmd.token
+ )
+ if updates:
+ await self.handler.on_rdata(
+ cmd.stream_name,
+ current_token,
+ [stream.parse_row(update[1]) for update in updates],
+ )
+
+ # We've now caught up to position sent to us, notify handler.
+ await self.handler.on_position(cmd.stream_name, cmd.token)
+
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
- await self.handler.on_position(cmd.stream_name, cmd.token)
+ # Handle any RDATA that came in while we were catching up.
+ rows = self.pending_batches.pop(cmd.stream_name, [])
+ if rows:
+ await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 5be31024b7..757129b6d5 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -227,8 +227,7 @@ class ReplicationStreamer(object):
self.pending_updates = False
self.is_looping = False
- @measure_func("repl.get_stream_updates")
- async def get_stream_updates(self, stream_name, token):
+ def get_stream_token(self, stream_name):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
@@ -236,7 +235,7 @@ class ReplicationStreamer(object):
if not stream:
raise Exception("unknown stream %s", stream_name)
- return await stream.get_updates_since(token, stream.current_token())
+ return stream.current_token()
@measure_func("repl.federation_ack")
def federation_ack(self, token):
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index c3b9a90ca5..6f5da99f85 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -27,7 +27,8 @@ Each stream is defined by the following information:
from typing import Dict, Type
-from . import _base, events, federation
+from synapse.replication.tcp.streams import _base, events, federation
+from synapse.replication.tcp.streams._base import Stream
STREAMS_MAP = {
stream.NAME: stream
@@ -50,3 +51,6 @@ STREAMS_MAP = {
_base.UserSignatureStream,
)
} # type: Dict[str, Type[_base.Stream]]
+
+
+__all__ = ["Stream", "STREAMS_MAP"]
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index e96ad4ca4e..b7a61e22f2 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from mock import Mock
from synapse.replication.tcp.commands import ReplicateCommand
@@ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = server_factory.streamer
- server = server_factory.buildProtocol(None)
+ self.server = server_factory.buildProtocol(None)
- # build a replication client, with a dummy handler
- handler_factory = Mock()
- self.test_handler = TestReplicationClientHandler()
- self.test_handler.factory = handler_factory
+ self.test_handler = Mock(wraps=TestReplicationClientHandler())
self.client = ClientReplicationStreamProtocol(
- "client", "test", clock, self.test_handler
+ hs, "client", "test", clock, self.test_handler,
)
- # wire them together
- self.client.makeConnection(FakeTransport(server, reactor))
- server.makeConnection(FakeTransport(self.client, reactor))
+ self._client_transport = None
+ self._server_transport = None
+
+ def reconnect(self):
+ if self._client_transport:
+ self.client.close()
+
+ if self._server_transport:
+ self.server.close()
+
+ self._client_transport = FakeTransport(self.server, self.reactor)
+ self.client.makeConnection(self._client_transport)
+
+ self._server_transport = FakeTransport(self.client, self.reactor)
+ self.server.makeConnection(self._server_transport)
+
+ def disconnect(self):
+ if self._client_transport:
+ self._client_transport = None
+ self.client.close()
+
+ if self._server_transport:
+ self._server_transport = None
+ self.server.close()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -59,10 +78,15 @@ class TestReplicationClientHandler(object):
"""Drop-in for ReplicationClientHandler which just collects RDATA rows"""
def __init__(self):
- self.received_rdata_rows = []
+ self.streams = set()
+ self._received_rdata_rows = []
def get_streams_to_replicate(self):
- return {}
+ positions = {s: 0 for s in self.streams}
+ for stream, token, _ in self._received_rdata_rows:
+ if stream in self.streams:
+ positions[stream] = max(token, positions.get(stream, 0))
+ return positions
def get_currently_syncing_users(self):
return []
@@ -73,6 +97,9 @@ class TestReplicationClientHandler(object):
def finished_connecting(self):
pass
+ async def on_position(self, stream_name, token):
+ """Called when we get new position data."""
+
async def on_rdata(self, stream_name, token, rows):
for r in rows:
- self.received_rdata_rows.append((stream_name, token, r))
+ self._received_rdata_rows.append((stream_name, token, r))
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index d5a99f6caa..28862b2fe5 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -17,30 +17,64 @@ from synapse.replication.tcp.streams._base import ReceiptsStreamRow
from tests.replication.tcp.streams._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
-ROOM_ID = "!room:blue"
-EVENT_ID = "$event:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self):
+ self.reconnect()
+
# make the client subscribe to the receipts stream
self.replicate_stream("receipts", "NOW")
+ self.test_handler.streams.add("receipts")
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
- ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
+ "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
)
)
self.replicate()
# there should be one RDATA command
- rdata_rows = self.test_handler.received_rdata_rows
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
- self.assertEqual(rdata_rows[0][0], "receipts")
- row = rdata_rows[0][2] # type: ReceiptsStreamRow
- self.assertEqual(ROOM_ID, row.room_id)
+ row = rdata_rows[0] # type: ReceiptsStreamRow
+ self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
- self.assertEqual(EVENT_ID, row.event_id)
+ self.assertEqual("$event:blue", row.event_id)
self.assertEqual({"a": 1}, row.data)
+
+ # Now let's disconnect and insert some data.
+ self.disconnect()
+
+ self.test_handler.on_rdata.reset_mock()
+
+ self.get_success(
+ self.hs.get_datastore().insert_receipt(
+ "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+ )
+ )
+ self.replicate()
+
+ # Nothing should have happened as we are disconnected
+ self.test_handler.on_rdata.assert_not_called()
+
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now have caught up and get the missing data
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
+ self.assertEqual(token, 3)
+ self.assertEqual(1, len(rdata_rows))
+
+ row = rdata_rows[0] # type: ReceiptsStreamRow
+ self.assertEqual("!room2:blue", row.room_id)
+ self.assertEqual("m.read", row.receipt_type)
+ self.assertEqual(USER_ID, row.user_id)
+ self.assertEqual("$event2:foo", row.event_id)
+ self.assertEqual({"a": 2}, row.data)
|