summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-03-03 14:19:23 +0000
committerErik Johnston <erik@matrix.org>2020-03-20 15:31:49 +0000
commit1f83255de17eb2de35fc42b91ebaaaf895771aa6 (patch)
treef5c502d0792916db1aa4ee9bb7d36b2c2fb74c7f /synapse
parentAdd ability to catchup on stream by talking to master. (diff)
downloadsynapse-1f83255de17eb2de35fc42b91ebaaaf895771aa6.tar.xz
Move stream catchup to workers.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/replication/tcp/client.py3
-rw-r--r--synapse/replication/tcp/protocol.py105
-rw-r--r--synapse/replication/tcp/resource.py5
-rw-r--r--synapse/replication/tcp/streams/__init__.py6
4 files changed, 54 insertions, 65 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 02ab5b66ea..7e7ad0f798 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         self.client_name = client_name
         self.handler = handler
         self.server_name = hs.config.server_name
+        self.hs = hs
         self._clock = hs.get_clock()  # As self.clock is defined in super class
 
         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     def buildProtocol(self, addr):
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
-            self.client_name, self.server_name, self._clock, self.handler
+            self.hs, self.client_name, self.server_name, self._clock, self.handler,
         )
 
     def clientConnectionLost(self, connector, reason):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index d7ef2398fa..649312f022 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -82,7 +82,8 @@ from synapse.replication.tcp.commands import (
     SyncCommand,
     UserSyncCommand,
 )
-from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.server import HomeServer
 from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
@@ -414,9 +415,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         # The streams the client has subscribed to and is up to date with
         self.replication_streams = set()  # type: Set[str]
 
-        # The streams the client is currently subscribing to.
-        self.connecting_streams = set()  # type:  Set[str]
-
         # Map from stream name to list of updates to send once we've finished
         # subscribing the client to the stream.
         self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]]
@@ -482,67 +480,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         are queued and sent once we've sent down any missed updates.
         """
         self.replication_streams.discard(stream_name)
-        self.connecting_streams.add(stream_name)
 
         try:
-            limited = True
-            while limited:
-                # Get missing updates
-                (
-                    updates,
-                    current_token,
-                    limited,
-                ) = await self.streamer.get_stream_updates(stream_name, token)
-
-                # Send all the missing updates
-                for update in updates:
-                    token, row = update[0], update[1]
-                    self.send_command(RdataCommand(stream_name, token, row))
+            # Get current stream position.
+            current_token = self.streamer.get_stream_token(stream_name)
 
             # We send a POSITION command to ensure that they have an up to
             # date token (especially useful if we didn't send any updates
             # above)
             self.send_command(PositionCommand(stream_name, current_token))
 
-            # Now we can send any updates that came in while we were subscribing
-            pending_rdata = self.pending_rdata.pop(stream_name, [])
-            updates = []
-            for token, update in pending_rdata:
-                # If the token is null, it is part of a batch update. Batches
-                # are multiple updates that share a single token. To denote
-                # this, the token is set to None for all tokens in the batch
-                # except for the last. If we find a None token, we keep looking
-                # through tokens until we find one that is not None and then
-                # process all previous updates in the batch as if they had the
-                # final token.
-                if token is None:
-                    # Store this update as part of a batch
-                    updates.append(update)
-                    continue
-
-                if token <= current_token:
-                    # This update or batch of updates is older than
-                    # current_token, dismiss it
-                    updates = []
-                    continue
-
-                updates.append(update)
-
-                # Send all updates that are part of this batch with the
-                # found token
-                for update in updates:
-                    self.send_command(RdataCommand(stream_name, token, update))
-
-                # Clear stored updates
-                updates = []
-
             # They're now fully subscribed
             self.replication_streams.add(stream_name)
         except Exception as e:
             logger.exception("[%s] Failed to handle REPLICATE command", self.id())
             self.send_error("failed to handle replicate: %r", e)
-        finally:
-            self.connecting_streams.discard(stream_name)
 
     def stream_update(self, stream_name, token, data):
         """Called when a new update is available to stream to clients.
@@ -552,10 +504,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         if stream_name in self.replication_streams:
             # The client is subscribed to the stream
             self.send_command(RdataCommand(stream_name, token, data))
-        elif stream_name in self.connecting_streams:
-            # The client is being subscribed to the stream
-            logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
-            self.pending_rdata.setdefault(stream_name, []).append((token, data))
         else:
             # The client isn't subscribed
             logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
@@ -642,6 +590,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     def __init__(
         self,
+        hs: HomeServer,
         client_name: str,
         server_name: str,
         clock: Clock,
@@ -653,6 +602,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.server_name = server_name
         self.handler = handler
 
+        self.streams = {
+            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+        }  # type: Dict[str, Stream]
+
         # Set of stream names that have been subscribe to, but haven't yet
         # caught up with. This is used to track when the client has been fully
         # connected to the remote.
@@ -660,7 +613,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
-        self.pending_batches = {}  # type: Dict[str, Any]
+        self.pending_batches = {}  # type: Dict[str, List[Any]]
 
     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
@@ -701,7 +654,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             )
             raise
 
-        if cmd.token is None:
+        if cmd.token is None or stream_name in self.streams_connecting:
             # I.e. this is part of a batch of updates for this stream. Batch
             # until we get an update for the stream with a non None token
             self.pending_batches.setdefault(stream_name, []).append(row)
@@ -711,14 +664,46 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             rows.append(row)
             await self.handler.on_rdata(stream_name, cmd.token, rows)
 
-    async def on_POSITION(self, cmd):
+    async def on_POSITION(self, cmd: PositionCommand):
+        stream = self.streams.get(cmd.stream_name)
+        if not stream:
+            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+            return
+
+        # Find where we previously streamed up to.
+        current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
+        if current_token is None:
+            logger.warning(
+                "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
+            )
+            return
+
+        # Fetch all updates between then and now.
+        limited = True
+        while limited:
+            updates, current_token, limited = await stream.get_updates_since(
+                current_token, cmd.token
+            )
+            if updates:
+                await self.handler.on_rdata(
+                    cmd.stream_name,
+                    current_token,
+                    [stream.parse_row(update[1]) for update in updates],
+                )
+
+        # We've now caught up to position sent to us, notify handler.
+        await self.handler.on_position(cmd.stream_name, cmd.token)
+
         # When we get a `POSITION` command it means we've finished getting
         # missing updates for the given stream, and are now up to date.
         self.streams_connecting.discard(cmd.stream_name)
         if not self.streams_connecting:
             self.handler.finished_connecting()
 
-        await self.handler.on_position(cmd.stream_name, cmd.token)
+        # Handle any RDATA that came in while we were catching up.
+        rows = self.pending_batches.pop(cmd.stream_name, [])
+        if rows:
+            await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
 
     async def on_SYNC(self, cmd):
         self.handler.on_sync(cmd.data)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 5be31024b7..757129b6d5 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -227,8 +227,7 @@ class ReplicationStreamer(object):
             self.pending_updates = False
             self.is_looping = False
 
-    @measure_func("repl.get_stream_updates")
-    async def get_stream_updates(self, stream_name, token):
+    def get_stream_token(self, stream_name):
         """For a given stream get all updates since token. This is called when
         a client first subscribes to a stream.
         """
@@ -236,7 +235,7 @@ class ReplicationStreamer(object):
         if not stream:
             raise Exception("unknown stream %s", stream_name)
 
-        return await stream.get_updates_since(token, stream.current_token())
+        return stream.current_token()
 
     @measure_func("repl.federation_ack")
     def federation_ack(self, token):
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index c3b9a90ca5..6f5da99f85 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -27,7 +27,8 @@ Each stream is defined by the following information:
 
 from typing import Dict, Type
 
-from . import _base, events, federation
+from synapse.replication.tcp.streams import _base, events, federation
+from synapse.replication.tcp.streams._base import Stream
 
 STREAMS_MAP = {
     stream.NAME: stream
@@ -50,3 +51,6 @@ STREAMS_MAP = {
         _base.UserSignatureStream,
     )
 }  # type: Dict[str, Type[_base.Stream]]
+
+
+__all__ = ["Stream", "STREAMS_MAP"]