summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
authorBen Banfield-Zanin <benbz@matrix.org>2020-12-16 14:49:53 +0000
committerBen Banfield-Zanin <benbz@matrix.org>2020-12-16 14:49:53 +0000
commit0825299cfcf61079f78b7a6c5e31f5df078c291a (patch)
tree5f469584845d065c79f1f6ed4781d0624e87f4d3 /synapse/replication/tcp
parentMerge remote-tracking branch 'origin/release-v1.21.2' into bbz/info-mainline-... (diff)
parentAdd 'xmlsec1' to dependency list (diff)
downloadsynapse-github/bbz/info-mainline-1.24.0.tar.xz
Merge remote-tracking branch 'origin/release-v1.24.0' into bbz/info-mainline-1.24.0 github/bbz/info-mainline-1.24.0 bbz/info-mainline-1.24.0
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py24
-rw-r--r--synapse/replication/tcp/commands.py36
-rw-r--r--synapse/replication/tcp/handler.py30
-rw-r--r--synapse/replication/tcp/protocol.py10
-rw-r--r--synapse/replication/tcp/redis.py44
-rw-r--r--synapse/replication/tcp/resource.py57
-rw-r--r--synapse/replication/tcp/streams/_base.py11
-rw-r--r--synapse/replication/tcp/streams/events.py27
8 files changed, 187 insertions, 52 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e165429cad..2618eb1e53 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -141,21 +141,25 @@ class ReplicationDataHandler:
                 if row.type != EventsStreamEventRow.TypeId:
                     continue
                 assert isinstance(row, EventsStreamRow)
+                assert isinstance(row.data, EventsStreamEventRow)
 
-                event = await self.store.get_event(
-                    row.data.event_id, allow_rejected=True
-                )
-                if event.rejected_reason:
+                if row.data.rejected:
                     continue
 
                 extra_users = ()  # type: Tuple[UserID, ...]
-                if event.type == EventTypes.Member:
-                    extra_users = (UserID.from_string(event.state_key),)
+                if row.data.type == EventTypes.Member and row.data.state_key:
+                    extra_users = (UserID.from_string(row.data.state_key),)
 
                 max_token = self.store.get_room_max_token()
                 event_pos = PersistedEventPosition(instance_name, token)
-                self.notifier.on_new_room_event(
-                    event, event_pos, max_token, extra_users
+                self.notifier.on_new_room_event_args(
+                    event_pos=event_pos,
+                    max_room_stream_token=max_token,
+                    extra_users=extra_users,
+                    room_id=row.data.room_id,
+                    event_type=row.data.type,
+                    state_key=row.data.state_key,
+                    membership=row.data.membership,
                 )
 
         # Notify any waiting deferreds. The list is ordered by position so we
@@ -191,6 +195,10 @@ class ReplicationDataHandler:
     async def on_position(self, stream_name: str, instance_name: str, token: int):
         self.store.process_replication_rows(stream_name, instance_name, token, [])
 
+        # We poke the generic "replication" notifier to wake anything up that
+        # may be streaming.
+        self.notifier.notify_replication()
+
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
 
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 8cd47770c1..ac532ed588 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -141,15 +141,23 @@ class RdataCommand(Command):
 
 
 class PositionCommand(Command):
-    """Sent by the server to tell the client the stream position without
-    needing to send an RDATA.
+    """Sent by an instance to tell others the stream position without needing to
+    send an RDATA.
+
+    Two tokens are sent, the new position and the last position sent by the
+    instance (in an RDATA or other POSITION). The tokens are chosen so that *no*
+    rows were written by the instance between the `prev_token` and `new_token`.
+    (If an instance hasn't sent a position before then the new position can be
+    used for both.)
 
     Format::
 
-        POSITION <stream_name> <instance_name> <token>
+        POSITION <stream_name> <instance_name> <prev_token> <new_token>
 
-    On receipt of a POSITION command clients should check if they have missed
-    any updates, and if so then fetch them out of band.
+    On receipt of a POSITION command instances should check if they have missed
+    any updates, and if so then fetch them out of band. Instances can check this
+    by comparing their view of the current token for the sending instance with
+    the included `prev_token`.
 
     The `<instance_name>` is the process that sent the command and is the source
     of the stream.
@@ -157,18 +165,26 @@ class PositionCommand(Command):
 
     NAME = "POSITION"
 
-    def __init__(self, stream_name, instance_name, token):
+    def __init__(self, stream_name, instance_name, prev_token, new_token):
         self.stream_name = stream_name
         self.instance_name = instance_name
-        self.token = token
+        self.prev_token = prev_token
+        self.new_token = new_token
 
     @classmethod
     def from_line(cls, line):
-        stream_name, instance_name, token = line.split(" ", 2)
-        return cls(stream_name, instance_name, int(token))
+        stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
+        return cls(stream_name, instance_name, int(prev_token), int(new_token))
 
     def to_line(self):
-        return " ".join((self.stream_name, self.instance_name, str(self.token)))
+        return " ".join(
+            (
+                self.stream_name,
+                self.instance_name,
+                str(self.prev_token),
+                str(self.new_token),
+            )
+        )
 
 
 class ErrorCommand(_SimpleCommand):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..95e5502bf2 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -101,8 +101,9 @@ class ReplicationCommandHandler:
         self._streams_to_replicate = []  # type: List[Stream]
 
         for stream in self._streams.values():
-            if stream.NAME == CachesStream.NAME:
-                # All workers can write to the cache invalidation stream.
+            if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
+                # All workers can write to the cache invalidation stream when
+                # using redis.
                 self._streams_to_replicate.append(stream)
                 continue
 
@@ -251,10 +252,9 @@ class ReplicationCommandHandler:
         using TCP.
         """
         if hs.config.redis.redis_enabled:
-            import txredisapi
-
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
+                lazyConnection,
             )
 
             logger.info(
@@ -271,7 +271,8 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = txredisapi.lazyConnection(
+            outbound_redis_connection = lazyConnection(
+                reactor=hs.get_reactor(),
                 host=hs.config.redis_host,
                 port=hs.config.redis_port,
                 password=hs.config.redis.redis_password,
@@ -313,11 +314,14 @@ class ReplicationCommandHandler:
         # We respond with current position of all streams this instance
         # replicates.
         for stream in self.get_streams_to_replicate():
+            # Note that we use the current token as the prev token here (rather
+            # than stream.last_token), as we can't be sure that there have been
+            # no rows written between last token and the current token (since we
+            # might be racing with the replication sending bg process).
+            current_token = stream.current_token(self._instance_name)
             self.send_command(
                 PositionCommand(
-                    stream.NAME,
-                    self._instance_name,
-                    stream.current_token(self._instance_name),
+                    stream.NAME, self._instance_name, current_token, current_token,
                 )
             )
 
@@ -511,16 +515,16 @@ class ReplicationCommandHandler:
         # If the position token matches our current token then we're up to
         # date and there's nothing to do. Otherwise, fetch all updates
         # between then and now.
-        missing_updates = cmd.token != current_token
+        missing_updates = cmd.prev_token != current_token
         while missing_updates:
             logger.info(
                 "Fetching replication rows for '%s' between %i and %i",
                 stream_name,
                 current_token,
-                cmd.token,
+                cmd.new_token,
             )
             (updates, current_token, missing_updates) = await stream.get_updates_since(
-                cmd.instance_name, current_token, cmd.token
+                cmd.instance_name, current_token, cmd.new_token
             )
 
             # TODO: add some tests for this
@@ -536,11 +540,11 @@ class ReplicationCommandHandler:
                     [stream.parse_row(row) for row in rows],
                 )
 
-        logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+        logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)
 
         # We've now caught up to position sent to us, notify handler.
         await self._replication_data_handler.on_position(
-            cmd.stream_name, cmd.instance_name, cmd.token
+            cmd.stream_name, cmd.instance_name, cmd.new_token
         )
 
         self._streams_by_connection.setdefault(conn, set()).add(stream_name)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0b0d204e64..a509e599c2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -51,10 +51,11 @@ import fcntl
 import logging
 import struct
 from inspect import isawaitable
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 from prometheus_client import Counter
 
+from twisted.internet import task
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
@@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
-        self.time_we_closed = None  # When we requested the connection be closed
+        # When we requested the connection be closed
+        self.time_we_closed = None  # type: Optional[int]
 
-        self.received_ping = False  # Have we reecived a ping from the other side
+        self.received_ping = False  # Have we received a ping from the other side
 
         self.state = ConnectionStates.CONNECTING
 
@@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.pending_commands = []  # type: List[Command]
 
         # The LoopingCall for sending pings.
-        self._send_ping_loop = None
+        self._send_ping_loop = None  # type: Optional[task.LoopingCall]
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index f225e533de..bc6ba709a7 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import txredisapi
 
@@ -166,7 +166,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         Args:
             cmd (Command)
         """
-        run_as_background_process("send-cmd", self._async_send_command, cmd)
+        run_as_background_process(
+            "send-cmd", self._async_send_command, cmd, bg_start_span=False
+        )
 
     async def _async_send_command(self, cmd: Command):
         """Encode a replication command and send it over our outbound connection"""
@@ -228,3 +230,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
         p.password = self.password
 
         return p
+
+
+def lazyConnection(
+    reactor,
+    host: str = "localhost",
+    port: int = 6379,
+    dbid: Optional[int] = None,
+    reconnect: bool = True,
+    charset: str = "utf-8",
+    password: Optional[str] = None,
+    connectTimeout: Optional[int] = None,
+    replyTimeout: Optional[int] = None,
+    convertNumbers: bool = True,
+) -> txredisapi.RedisProtocol:
+    """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
+    reactor.
+    """
+
+    isLazy = True
+    poolsize = 1
+
+    uuid = "%s:%d" % (host, port)
+    factory = txredisapi.RedisFactory(
+        uuid,
+        dbid,
+        poolsize,
+        isLazy,
+        txredisapi.ConnectionHandler,
+        charset,
+        password,
+        replyTimeout,
+        convertNumbers,
+    )
+    factory.continueTrying = reconnect
+    for x in range(poolsize):
+        reactor.connectTCP(host, port, factory, connectTimeout)
+
+    return factory.handler
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 687984e7a8..1d4ceac0f1 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -23,7 +23,9 @@ from prometheus_client import Counter
 from twisted.internet.protocol import Factory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.commands import PositionCommand
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import EventsStream
 from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
@@ -84,6 +86,23 @@ class ReplicationStreamer:
         # Set of streams to replicate.
         self.streams = self.command_handler.get_streams_to_replicate()
 
+        # If we have streams then we must have redis enabled or on master
+        assert (
+            not self.streams
+            or hs.config.redis.redis_enabled
+            or not hs.config.worker.worker_app
+        )
+
+        # If we are replicating an event stream we want to periodically check if
+        # we should send updated POSITIONs. We do this as a looping call rather
+        # explicitly poking when the position advances (without new data to
+        # replicate) to reduce replication traffic (otherwise each writer would
+        # likely send a POSITION for each new event received over replication).
+        #
+        # Note that if the position hasn't advanced then we won't send anything.
+        if any(EventsStream.NAME == s.NAME for s in self.streams):
+            self.clock.looping_call(self.on_notifier_poke, 1000)
+
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
@@ -91,13 +110,23 @@ class ReplicationStreamer:
         This should get called each time new data is available, even if it
         is currently being executed, so that nothing gets missed
         """
-        if not self.command_handler.connected():
+        if not self.command_handler.connected() or not self.streams:
             # Don't bother if nothing is listening. We still need to advance
             # the stream tokens otherwise they'll fall behind forever
             for stream in self.streams:
                 stream.discard_updates_and_advance()
             return
 
+        # We check up front to see if anything has actually changed, as we get
+        # poked because of changes that happened on other instances.
+        if all(
+            stream.last_token == stream.current_token(self._instance_name)
+            for stream in self.streams
+        ):
+            return
+
+        # If there are updates then we need to set this even if we're already
+        # looping, as the loop needs to know that he might need to loop again.
         self.pending_updates = True
 
         if self.is_looping:
@@ -136,6 +165,8 @@ class ReplicationStreamer:
                                 self._replication_torture_level / 1000.0
                             )
 
+                        last_token = stream.last_token
+
                         logger.debug(
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
@@ -159,6 +190,30 @@ class ReplicationStreamer:
                             )
                             stream_updates_counter.labels(stream.NAME).inc(len(updates))
 
+                        else:
+                            # The token has advanced but there is no data to
+                            # send, so we send a `POSITION` to inform other
+                            # workers of the updated position.
+                            if stream.NAME == EventsStream.NAME:
+                                # XXX: We only do this for the EventStream as it
+                                # turns out that e.g. account data streams share
+                                # their "current token" with each other, meaning
+                                # that it is *not* safe to send a POSITION.
+                                logger.info(
+                                    "Sending position: %s -> %s",
+                                    stream.NAME,
+                                    current_token,
+                                )
+                                self.command_handler.send_command(
+                                    PositionCommand(
+                                        stream.NAME,
+                                        self._instance_name,
+                                        last_token,
+                                        current_token,
+                                    )
+                                )
+                            continue
+
                         # Some streams return multiple rows with the same stream IDs,
                         # we need to make sure they get sent out in batches. We do
                         # this by setting the current token to all but the last of
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 54dccd15a6..61b282ab2d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -240,13 +240,18 @@ class BackfillStream(Stream):
     ROW_TYPE = BackfillStreamRow
 
     def __init__(self, hs):
-        store = hs.get_datastore()
+        self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_current_backfill_token),
-            store.get_all_new_backfill_event_rows,
+            self._current_token,
+            self.store.get_all_new_backfill_event_rows,
         )
 
+    def _current_token(self, instance_name: str) -> int:
+        # The backfill stream over replication operates on *positive* numbers,
+        # which means we need to negate it.
+        return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
+
 
 class PresenceStream(Stream):
     PresenceStreamRow = namedtuple(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ccc7ca30d8..86a62b71eb 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,12 +15,15 @@
 # limitations under the License.
 import heapq
 from collections.abc import Iterable
-from typing import List, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
 
 import attr
 
 from ._base import Stream, StreamUpdateResult, Token
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 """Handling of the 'events' replication stream
 
 This stream contains rows of various types. Each row therefore contains a 'type'
@@ -81,12 +84,14 @@ class BaseEventsStreamRow:
 class EventsStreamEventRow(BaseEventsStreamRow):
     TypeId = "ev"
 
-    event_id = attr.ib()  # str
-    room_id = attr.ib()  # str
-    type = attr.ib()  # str
-    state_key = attr.ib()  # str, optional
-    redacts = attr.ib()  # str, optional
-    relates_to = attr.ib()  # str, optional
+    event_id = attr.ib(type=str)
+    room_id = attr.ib(type=str)
+    type = attr.ib(type=str)
+    state_key = attr.ib(type=Optional[str])
+    redacts = attr.ib(type=Optional[str])
+    relates_to = attr.ib(type=Optional[str])
+    membership = attr.ib(type=Optional[str])
+    rejected = attr.ib(type=bool)
 
 
 @attr.s(slots=True, frozen=True)
@@ -113,7 +118,7 @@ class EventsStream(Stream):
 
     NAME = "events"
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
@@ -155,7 +160,7 @@ class EventsStream(Stream):
         # now we fetch up to that many rows from the events table
 
         event_rows = await self._store.get_all_new_forward_event_rows(
-            from_token, current_token, target_row_count
+            instance_name, from_token, current_token, target_row_count
         )  # type: List[Tuple]
 
         # we rely on get_all_new_forward_event_rows strictly honouring the limit, so
@@ -180,7 +185,7 @@ class EventsStream(Stream):
             upper_limit,
             state_rows_limited,
         ) = await self._store.get_all_updated_current_state_deltas(
-            from_token, upper_limit, target_row_count
+            instance_name, from_token, upper_limit, target_row_count
         )
 
         limited = limited or state_rows_limited
@@ -189,7 +194,7 @@ class EventsStream(Stream):
         # not to bother with the limit.
 
         ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
-            from_token, upper_limit
+            instance_name, from_token, upper_limit
         )  # type: List[Tuple]
 
         # we now need to turn the raw database rows returned into tuples suitable