diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 02ab5b66ea..508ad1b720 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,47 +14,55 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""
-
+import heapq
import logging
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Tuple
-from twisted.internet import defer
+from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.protocol import (
- AbstractReplicationClientHandler,
- ClientReplicationStreamProtocol,
+from synapse.api.constants import EventTypes
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.streams.events import (
+ EventsStream,
+ EventsStreamEventRow,
+ EventsStreamRow,
)
+from synapse.util.async_helpers import timeout_deferred
+from synapse.util.metrics import Measure
-from .commands import (
- Command,
- FederationAckCommand,
- InvalidateCacheCommand,
- RemoteServerUpCommand,
- RemovePusherCommand,
- UserIpCommand,
- UserSyncCommand,
-)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.replication.tcp.handler import ReplicationCommandHandler
logger = logging.getLogger(__name__)
-class ReplicationClientFactory(ReconnectingClientFactory):
+# How long we allow callers to wait for replication updates before timing out.
+_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
+
+
+class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
"""Factory for building connections to the master. Will reconnect if the
connection is lost.
- Accepts a handler that will be called when new data is available or data
- is required.
+ Accepts a handler that is passed to `ClientReplicationStreamProtocol`.
"""
initialDelay = 0.1
maxDelay = 1 # Try at least once every N seconds
- def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ client_name: str,
+ command_handler: "ReplicationCommandHandler",
+ ):
self.client_name = client_name
- self.handler = handler
+ self.command_handler = command_handler
self.server_name = hs.config.server_name
+ self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +73,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
- self.client_name, self.server_name, self._clock, self.handler
+ self.hs,
+ self.client_name,
+ self.server_name,
+ self._clock,
+ self.command_handler,
)
def clientConnectionLost(self, connector, reason):
@@ -77,168 +89,136 @@ class ReplicationClientFactory(ReconnectingClientFactory):
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-class ReplicationClientHandler(AbstractReplicationClientHandler):
- """A base handler that can be passed to the ReplicationClientFactory.
+class ReplicationDataHandler:
+ """Handles incoming stream updates from replication.
- By default proxies incoming replication data to the SlaveStore.
+ This instance notifies the slave data store about updates. Can be subclassed
+ to handle updates in additional ways.
"""
- def __init__(self, store: BaseSlavedStore):
- self.store = store
-
- # The current connection. None if we are currently (re)connecting
- self.connection = None
-
- # Any pending commands to be sent once a new connection has been
- # established
- self.pending_commands = [] # type: List[Command]
-
- # Map from string -> deferred, to wake up when receiveing a SYNC with
- # the given string.
- # Used for tests.
- self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
-
- # The factory used to create connections.
- self.factory = None # type: Optional[ReplicationClientFactory]
-
- def start_replication(self, hs):
- """Helper method to start a replication connection to the remote server
- using TCP.
- """
- client_name = hs.config.worker_name
- self.factory = ReplicationClientFactory(hs, client_name, self)
- host = hs.config.worker_replication_host
- port = hs.config.worker_replication_port
- hs.get_reactor().connectTCP(host, port, self.factory)
-
- async def on_rdata(self, stream_name, token, rows):
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastore()
+ self.pusher_pool = hs.get_pusherpool()
+ self.notifier = hs.get_notifier()
+ self._reactor = hs.get_reactor()
+ self._clock = hs.get_clock()
+ self._streams = hs.get_replication_streams()
+ self._instance_name = hs.get_instance_name()
+
+        # Map from stream name to a list of deferreds waiting for the stream to
+        # arrive at a particular position. Each list is sorted by stream position.
+ self._streams_to_waiters = (
+ {}
+ ) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
+
+ async def on_rdata(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
- stream_name (str): name of the replication stream for this batch of rows
- token (int): stream token for this batch of rows
- rows (list): a list of Stream.ROW_TYPE objects as returned by
- Stream.parse_row.
- """
- logger.debug("Received rdata %s -> %s", stream_name, token)
- self.store.process_replication_rows(stream_name, token, rows)
-
- async def on_position(self, stream_name, token):
- """Called when we get new position data. By default this just pokes
- the slave store.
-
- Can be overriden in subclasses to handle more.
+ stream_name: name of the replication stream for this batch of rows
+ instance_name: the instance that wrote the rows.
+ token: stream token for this batch of rows
+ rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
- self.store.process_replication_rows(stream_name, token, [])
-
- def on_sync(self, data):
- """When we received a SYNC we wake up any deferreds that were waiting
- for the sync with the given data.
-
- Used by tests.
- """
- d = self.awaiting_syncs.pop(data, None)
- if d:
- d.callback(data)
+ self.store.process_replication_rows(stream_name, instance_name, token, rows)
+
+ if stream_name == EventsStream.NAME:
+            # We shouldn't get multiple rows per token for the events stream, so
+            # we don't need to optimise this for multiple rows.
+ for row in rows:
+ if row.type != EventsStreamEventRow.TypeId:
+ continue
+ assert isinstance(row, EventsStreamRow)
+
+ event = await self.store.get_event(
+ row.data.event_id, allow_rejected=True
+ )
+ if event.rejected_reason:
+ continue
+
+ extra_users = () # type: Tuple[str, ...]
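+                # For membership events, also notify the user the event is
+                # about, as they may not (yet) be joined to the room themselves.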
+ if event.type == EventTypes.Member:
+ extra_users = (event.state_key,)
+ max_token = self.store.get_room_max_stream_ordering()
+ self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+ await self.pusher_pool.on_new_notifications(token, token)
+
+ # Notify any waiting deferreds. The list is ordered by position so we
+ # just iterate through the list until we reach a position that is
+ # greater than the received row position.
+ waiting_list = self._streams_to_waiters.get(stream_name, [])
+
+        # Index of the first item with a position after the current token, i.e.
+        # we have called all deferreds before this index. If it is not
+        # overwritten by the loop below then either a) the list was empty and
+        # this is a no-op, or b) every deferred in the list was called and the
+        # list should be cleared. Setting it to `len(waiting_list)` works for
+        # both cases.
+ index_of_first_deferred_not_called = len(waiting_list)
+
+ for idx, (position, deferred) in enumerate(waiting_list):
+ if position <= token:
+ try:
+ with PreserveLoggingContext():
+ deferred.callback(None)
+ except Exception:
+ # The deferred has been cancelled or timed out.
+ pass
+ else:
+ # The list is sorted by position so we don't need to continue
+                # checking any further entries in the list.
+ index_of_first_deferred_not_called = idx
+ break
+
+        # Drop all entries in the waiting list that were called in the above
+        # loop. (This preserves the ordering, so there is no need to re-sort.)
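+        # For example, if the list was [(3, d1), (5, d2), (8, d3)] and the
+        # received token is 5, then d1 and d2 were fired above and the list
+        # becomes [(8, d3)].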
+ waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+
+ async def on_position(self, stream_name: str, instance_name: str, token: int):
+ self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
- def get_streams_to_replicate(self) -> Dict[str, int]:
- """Called when a new connection has been established and we need to
- subscribe to streams.
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
+ async def wait_for_stream_position(
+ self, instance_name: str, stream_name: str, position: int
+ ):
+ """Wait until this instance has received updates up to and including
+ the given stream position.
"""
- args = self.store.stream_positions()
- user_account_data = args.pop("user_account_data", None)
- room_account_data = args.pop("room_account_data", None)
- if user_account_data:
- args["account_data"] = user_account_data
- elif room_account_data:
- args["account_data"] = room_account_data
-
- return args
-
- def get_currently_syncing_users(self):
- """Get the list of currently syncing users (if any). This is called
- when a connection has been established and we need to send the
- currently syncing users. (Overriden by the synchrotron's only)
- """
- return []
- def send_command(self, cmd):
- """Send a command to master (when we get establish a connection if we
- don't have one already.)
- """
- if self.connection:
- self.connection.send_command(cmd)
- else:
- logger.warning("Queuing command as not connected: %r", cmd.NAME)
- self.pending_commands.append(cmd)
-
- def send_federation_ack(self, token):
- """Ack data for the federation stream. This allows the master to drop
- data stored purely in memory.
- """
- self.send_command(FederationAckCommand(token))
-
- def send_user_sync(self, user_id, is_syncing, last_sync_ms):
- """Poke the master that a user has started/stopped syncing.
- """
- self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
-
- def send_remove_pusher(self, app_id, push_key, user_id):
- """Poke the master to remove a pusher for a user
- """
- cmd = RemovePusherCommand(app_id, push_key, user_id)
- self.send_command(cmd)
-
- def send_invalidate_cache(self, cache_func, keys):
- """Poke the master to invalidate a cache.
- """
- cmd = InvalidateCacheCommand(cache_func.__name__, keys)
- self.send_command(cmd)
-
- def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
- """Tell the master that the user made a request.
- """
- cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
- self.send_command(cmd)
-
- def send_remote_server_up(self, server: str):
- self.send_command(RemoteServerUpCommand(server))
-
- def await_sync(self, data):
- """Returns a deferred that is resolved when we receive a SYNC command
- with given data.
+ if instance_name == self._instance_name:
+ # We don't get told about updates written by this process, and
+ # anyway in that case we don't need to wait.
+ return
+
+ current_position = self._streams[stream_name].current_token(self._instance_name)
+ if position <= current_position:
+ # We're already past the position
+ return
+
+ # Create a new deferred that times out after N seconds, as we don't want
+ # to wedge here forever.
+ deferred = Deferred()
+ deferred = timeout_deferred(
+ deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
+ )
- [Not currently] used by tests.
- """
- return self.awaiting_syncs.setdefault(data, defer.Deferred())
+ waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
- def update_connection(self, connection):
- """Called when a connection has been established (or lost with None).
- """
- self.connection = connection
- if connection:
- for cmd in self.pending_commands:
- connection.send_command(cmd)
- self.pending_commands = []
-
- def finished_connecting(self):
- """Called when we have successfully subscribed and caught up to all
- streams we're interested in.
- """
- logger.info("Finished connecting to server")
+        # We insert into the list using heapq as it is more efficient than
+        # appending and then re-sorting each time.
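+        # Entries are (position, deferred) tuples, so the heap is keyed on the
+        # stream position being waited for.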
+ heapq.heappush(waiting_list, (position, deferred))
- # We don't reset the delay any earlier as otherwise if there is a
- # problem during start up we'll end up tight looping connecting to the
- # server.
- if self.factory:
- self.factory.resetDelay()
+        # We measure here to get in-flight counts and the average waiting time.
+ with Measure(self._clock, "repl.wait_for_stream_position"):
+ logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
+ await make_deferred_yieldable(deferred)
+ logger.info(
+ "Finished waiting for repl stream %r to reach %s", stream_name, position
+ )