summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/_base.py2
-rw-r--r--synapse/replication/http/federation.py2
-rw-r--r--synapse/replication/http/membership.py2
-rw-r--r--synapse/replication/tcp/client.py4
-rw-r--r--synapse/replication/tcp/commands.py36
-rw-r--r--synapse/replication/tcp/handler.py30
-rw-r--r--synapse/replication/tcp/protocol.py10
-rw-r--r--synapse/replication/tcp/redis.py40
-rw-r--r--synapse/replication/tcp/resource.py47
-rw-r--r--synapse/replication/tcp/streams/_base.py11
-rw-r--r--synapse/replication/tcp/streams/events.py6
11 files changed, 152 insertions, 38 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 64edadb624..2b3972cb14 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -92,7 +92,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         if self.CACHE:
             self.response_cache = ResponseCache(
                 hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
-            )
+            )  # type: ResponseCache[str]
 
         # We reserve `instance_name` as a parameter to sending requests, so we
         # assert here that sub classes don't try and use the name.
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 5393b9a9e7..b4f4a68b5c 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
         self.clock = hs.get_clock()
-        self.federation_handler = hs.get_handlers().federation_handler
+        self.federation_handler = hs.get_federation_handler()
 
     @staticmethod
     async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 30680baee8..e7cc74a5d2 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -47,7 +47,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
     def __init__(self, hs):
         super().__init__(hs)
 
-        self.federation_handler = hs.get_handlers().federation_handler
+        self.federation_handler = hs.get_federation_handler()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
 
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e165429cad..e27ee216f0 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -191,6 +191,10 @@ class ReplicationDataHandler:
     async def on_position(self, stream_name: str, instance_name: str, token: int):
         self.store.process_replication_rows(stream_name, instance_name, token, [])
 
+        # We poke the generic "replication" notifier to wake anything up that
+        # may be streaming.
+        self.notifier.notify_replication()
+
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
 
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 8cd47770c1..ac532ed588 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -141,15 +141,23 @@ class RdataCommand(Command):
 
 
 class PositionCommand(Command):
-    """Sent by the server to tell the client the stream position without
-    needing to send an RDATA.
+    """Sent by an instance to tell others the stream position without needing to
+    send an RDATA.
+
+    Two tokens are sent, the new position and the last position sent by the
+    instance (in an RDATA or other POSITION). The tokens are chosen so that *no*
+    rows were written by the instance between the `prev_token` and `new_token`.
+    (If an instance hasn't sent a position before then the new position can be
+    used for both.)
 
     Format::
 
-        POSITION <stream_name> <instance_name> <token>
+        POSITION <stream_name> <instance_name> <prev_token> <new_token>
 
-    On receipt of a POSITION command clients should check if they have missed
-    any updates, and if so then fetch them out of band.
+    On receipt of a POSITION command instances should check if they have missed
+    any updates, and if so then fetch them out of band. Instances can check this
+    by comparing their view of the current token for the sending instance with
+    the included `prev_token`.
 
     The `<instance_name>` is the process that sent the command and is the source
     of the stream.
@@ -157,18 +165,26 @@ class PositionCommand(Command):
 
     NAME = "POSITION"
 
-    def __init__(self, stream_name, instance_name, token):
+    def __init__(self, stream_name, instance_name, prev_token, new_token):
         self.stream_name = stream_name
         self.instance_name = instance_name
-        self.token = token
+        self.prev_token = prev_token
+        self.new_token = new_token
 
     @classmethod
     def from_line(cls, line):
-        stream_name, instance_name, token = line.split(" ", 2)
-        return cls(stream_name, instance_name, int(token))
+        stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
+        return cls(stream_name, instance_name, int(prev_token), int(new_token))
 
     def to_line(self):
-        return " ".join((self.stream_name, self.instance_name, str(self.token)))
+        return " ".join(
+            (
+                self.stream_name,
+                self.instance_name,
+                str(self.prev_token),
+                str(self.new_token),
+            )
+        )
 
 
 class ErrorCommand(_SimpleCommand):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..95e5502bf2 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -101,8 +101,9 @@ class ReplicationCommandHandler:
         self._streams_to_replicate = []  # type: List[Stream]
 
         for stream in self._streams.values():
-            if stream.NAME == CachesStream.NAME:
-                # All workers can write to the cache invalidation stream.
+            if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
+                # All workers can write to the cache invalidation stream when
+                # using redis.
                 self._streams_to_replicate.append(stream)
                 continue
 
@@ -251,10 +252,9 @@ class ReplicationCommandHandler:
         using TCP.
         """
         if hs.config.redis.redis_enabled:
-            import txredisapi
-
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
+                lazyConnection,
             )
 
             logger.info(
@@ -271,7 +271,8 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = txredisapi.lazyConnection(
+            outbound_redis_connection = lazyConnection(
+                reactor=hs.get_reactor(),
                 host=hs.config.redis_host,
                 port=hs.config.redis_port,
                 password=hs.config.redis.redis_password,
@@ -313,11 +314,14 @@ class ReplicationCommandHandler:
         # We respond with current position of all streams this instance
         # replicates.
         for stream in self.get_streams_to_replicate():
+            # Note that we use the current token as the prev token here (rather
+            # than stream.last_token), as we can't be sure that there have been
+            # no rows written between last token and the current token (since we
+            # might be racing with the replication sending bg process).
+            current_token = stream.current_token(self._instance_name)
             self.send_command(
                 PositionCommand(
-                    stream.NAME,
-                    self._instance_name,
-                    stream.current_token(self._instance_name),
+                    stream.NAME, self._instance_name, current_token, current_token,
                 )
             )
 
@@ -511,16 +515,16 @@ class ReplicationCommandHandler:
         # If the position token matches our current token then we're up to
         # date and there's nothing to do. Otherwise, fetch all updates
         # between then and now.
-        missing_updates = cmd.token != current_token
+        missing_updates = cmd.prev_token != current_token
         while missing_updates:
             logger.info(
                 "Fetching replication rows for '%s' between %i and %i",
                 stream_name,
                 current_token,
-                cmd.token,
+                cmd.new_token,
             )
             (updates, current_token, missing_updates) = await stream.get_updates_since(
-                cmd.instance_name, current_token, cmd.token
+                cmd.instance_name, current_token, cmd.new_token
             )
 
             # TODO: add some tests for this
@@ -536,11 +540,11 @@ class ReplicationCommandHandler:
                     [stream.parse_row(row) for row in rows],
                 )
 
-        logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+        logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)
 
         # We've now caught up to position sent to us, notify handler.
         await self._replication_data_handler.on_position(
-            cmd.stream_name, cmd.instance_name, cmd.token
+            cmd.stream_name, cmd.instance_name, cmd.new_token
         )
 
         self._streams_by_connection.setdefault(conn, set()).add(stream_name)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0b0d204e64..a509e599c2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -51,10 +51,11 @@ import fcntl
 import logging
 import struct
 from inspect import isawaitable
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 from prometheus_client import Counter
 
+from twisted.internet import task
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
@@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
-        self.time_we_closed = None  # When we requested the connection be closed
+        # When we requested the connection be closed
+        self.time_we_closed = None  # type: Optional[int]
 
-        self.received_ping = False  # Have we reecived a ping from the other side
+        self.received_ping = False  # Have we received a ping from the other side
 
         self.state = ConnectionStates.CONNECTING
 
@@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.pending_commands = []  # type: List[Command]
 
         # The LoopingCall for sending pings.
-        self._send_ping_loop = None
+        self._send_ping_loop = None  # type: Optional[task.LoopingCall]
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index f225e533de..de19705c1f 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import txredisapi
 
@@ -228,3 +228,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
         p.password = self.password
 
         return p
+
+
+def lazyConnection(
+    reactor,
+    host: str = "localhost",
+    port: int = 6379,
+    dbid: Optional[int] = None,
+    reconnect: bool = True,
+    charset: str = "utf-8",
+    password: Optional[str] = None,
+    connectTimeout: Optional[int] = None,
+    replyTimeout: Optional[int] = None,
+    convertNumbers: bool = True,
+) -> txredisapi.RedisProtocol:
+    """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
+    reactor.
+    """
+
+    isLazy = True
+    poolsize = 1
+
+    uuid = "%s:%d" % (host, port)
+    factory = txredisapi.RedisFactory(
+        uuid,
+        dbid,
+        poolsize,
+        isLazy,
+        txredisapi.ConnectionHandler,
+        charset,
+        password,
+        replyTimeout,
+        convertNumbers,
+    )
+    factory.continueTrying = reconnect
+    for x in range(poolsize):
+        reactor.connectTCP(host, port, factory, connectTimeout)
+
+    return factory.handler
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 687984e7a8..666c13fdb7 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -23,7 +23,9 @@ from prometheus_client import Counter
 from twisted.internet.protocol import Factory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.commands import PositionCommand
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import EventsStream
 from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
@@ -84,6 +86,23 @@ class ReplicationStreamer:
         # Set of streams to replicate.
         self.streams = self.command_handler.get_streams_to_replicate()
 
+        # If we have streams then we must have redis enabled or on master
+        assert (
+            not self.streams
+            or hs.config.redis.redis_enabled
+            or not hs.config.worker.worker_app
+        )
+
+        # If we are replicating an event stream we want to periodically check if
+        # we should send updated POSITIONs. We do this as a looping call rather
+        # explicitly poking when the position advances (without new data to
+        # replicate) to reduce replication traffic (otherwise each writer would
+        # likely send a POSITION for each new event received over replication).
+        #
+        # Note that if the position hasn't advanced then we won't send anything.
+        if any(EventsStream.NAME == s.NAME for s in self.streams):
+            self.clock.looping_call(self.on_notifier_poke, 1000)
+
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
@@ -91,7 +110,7 @@ class ReplicationStreamer:
         This should get called each time new data is available, even if it
         is currently being executed, so that nothing gets missed
         """
-        if not self.command_handler.connected():
+        if not self.command_handler.connected() or not self.streams:
             # Don't bother if nothing is listening. We still need to advance
             # the stream tokens otherwise they'll fall behind forever
             for stream in self.streams:
@@ -136,6 +155,8 @@ class ReplicationStreamer:
                                 self._replication_torture_level / 1000.0
                             )
 
+                        last_token = stream.last_token
+
                         logger.debug(
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
@@ -159,6 +180,30 @@ class ReplicationStreamer:
                             )
                             stream_updates_counter.labels(stream.NAME).inc(len(updates))
 
+                        else:
+                            # The token has advanced but there is no data to
+                            # send, so we send a `POSITION` to inform other
+                            # workers of the updated position.
+                            if stream.NAME == EventsStream.NAME:
+                                # XXX: We only do this for the EventStream as it
+                                # turns out that e.g. account data streams share
+                                # their "current token" with each other, meaning
+                                # that it is *not* safe to send a POSITION.
+                                logger.info(
+                                    "Sending position: %s -> %s",
+                                    stream.NAME,
+                                    current_token,
+                                )
+                                self.command_handler.send_command(
+                                    PositionCommand(
+                                        stream.NAME,
+                                        self._instance_name,
+                                        last_token,
+                                        current_token,
+                                    )
+                                )
+                            continue
+
                         # Some streams return multiple rows with the same stream IDs,
                         # we need to make sure they get sent out in batches. We do
                         # this by setting the current token to all but the last of
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 54dccd15a6..61b282ab2d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -240,13 +240,18 @@ class BackfillStream(Stream):
     ROW_TYPE = BackfillStreamRow
 
     def __init__(self, hs):
-        store = hs.get_datastore()
+        self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_current_backfill_token),
-            store.get_all_new_backfill_event_rows,
+            self._current_token,
+            self.store.get_all_new_backfill_event_rows,
         )
 
+    def _current_token(self, instance_name: str) -> int:
+        # The backfill stream over replication operates on *positive* numbers,
+        # which means we need to negate it.
+        return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
+
 
 class PresenceStream(Stream):
     PresenceStreamRow = namedtuple(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ccc7ca30d8..82e9e0d64e 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -155,7 +155,7 @@ class EventsStream(Stream):
         # now we fetch up to that many rows from the events table
 
         event_rows = await self._store.get_all_new_forward_event_rows(
-            from_token, current_token, target_row_count
+            instance_name, from_token, current_token, target_row_count
         )  # type: List[Tuple]
 
         # we rely on get_all_new_forward_event_rows strictly honouring the limit, so
@@ -180,7 +180,7 @@ class EventsStream(Stream):
             upper_limit,
             state_rows_limited,
         ) = await self._store.get_all_updated_current_state_deltas(
-            from_token, upper_limit, target_row_count
+            instance_name, from_token, upper_limit, target_row_count
         )
 
         limited = limited or state_rows_limited
@@ -189,7 +189,7 @@ class EventsStream(Stream):
         # not to bother with the limit.
 
         ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
-            from_token, upper_limit
+            instance_name, from_token, upper_limit
         )  # type: List[Tuple]
 
         # we now need to turn the raw database rows returned into tuples suitable