summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8499.misc1
-rw-r--r--docs/tcp_replication.md13
-rw-r--r--synapse/replication/tcp/client.py4
-rw-r--r--synapse/replication/tcp/commands.py36
-rw-r--r--synapse/replication/tcp/handler.py24
-rw-r--r--synapse/replication/tcp/resource.py47
-rw-r--r--synapse/storage/util/id_generators.py10
-rw-r--r--synapse/storage/util/sequence.py2
-rw-r--r--tests/storage/test_id_generators.py25
9 files changed, 128 insertions, 34 deletions
diff --git a/changelog.d/8499.misc b/changelog.d/8499.misc
new file mode 100644
index 0000000000..237cb3b311
--- /dev/null
+++ b/changelog.d/8499.misc
@@ -0,0 +1 @@
+Allow events to be sent to clients sooner when using sharded event persisters.
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index db318baa9d..ad145439b4 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -15,7 +15,7 @@ example flow would be (where '>' indicates master to worker and
 
     > SERVER example.com
     < REPLICATE
-    > POSITION events master 53
+    > POSITION events master 53 53
     > RDATA events master 54 ["$foo1:bar.com", ...]
     > RDATA events master 55 ["$foo4:bar.com", ...]
 
@@ -138,9 +138,9 @@ the wire:
     < NAME synapse.app.appservice
     < PING 1490197665618
     < REPLICATE
-    > POSITION events master 1
-    > POSITION backfill master 1
-    > POSITION caches master 1
+    > POSITION events master 1 1
+    > POSITION backfill master 1 1
+    > POSITION caches master 1 1
     > RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
     > RDATA events master 14 ["$149019767112vOHxz:localhost:8823",
         "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]
@@ -185,6 +185,11 @@ client (C):
    updates via HTTP API, rather than via the DB, then processes should make the
    request to the appropriate process.
 
+   Two positions are included, the "new" position and the last position sent respectively.
+   This allows servers to tell instances that the positions have advanced but no
+   data has been written, without clients needlessly checking to see if they
+   have missed any updates.
+
 #### ERROR (S, C)
 
    There was an error
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e165429cad..e27ee216f0 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -191,6 +191,10 @@ class ReplicationDataHandler:
     async def on_position(self, stream_name: str, instance_name: str, token: int):
         self.store.process_replication_rows(stream_name, instance_name, token, [])
 
+        # We poke the generic "replication" notifier to wake anything up that
+        # may be streaming.
+        self.notifier.notify_replication()
+
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
 
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 8cd47770c1..ac532ed588 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -141,15 +141,23 @@ class RdataCommand(Command):
 
 
 class PositionCommand(Command):
-    """Sent by the server to tell the client the stream position without
-    needing to send an RDATA.
+    """Sent by an instance to tell others the stream position without needing to
+    send an RDATA.
+
+    Two tokens are sent, the new position and the last position sent by the
+    instance (in an RDATA or other POSITION). The tokens are chosen so that *no*
+    rows were written by the instance between the `prev_token` and `new_token`.
+    (If an instance hasn't sent a position before then the new position can be
+    used for both.)
 
     Format::
 
-        POSITION <stream_name> <instance_name> <token>
+        POSITION <stream_name> <instance_name> <prev_token> <new_token>
 
-    On receipt of a POSITION command clients should check if they have missed
-    any updates, and if so then fetch them out of band.
+    On receipt of a POSITION command instances should check if they have missed
+    any updates, and if so then fetch them out of band. Instances can check this
+    by comparing their view of the current token for the sending instance with
+    the included `prev_token`.
 
     The `<instance_name>` is the process that sent the command and is the source
     of the stream.
@@ -157,18 +165,26 @@ class PositionCommand(Command):
 
     NAME = "POSITION"
 
-    def __init__(self, stream_name, instance_name, token):
+    def __init__(self, stream_name, instance_name, prev_token, new_token):
         self.stream_name = stream_name
         self.instance_name = instance_name
-        self.token = token
+        self.prev_token = prev_token
+        self.new_token = new_token
 
     @classmethod
     def from_line(cls, line):
-        stream_name, instance_name, token = line.split(" ", 2)
-        return cls(stream_name, instance_name, int(token))
+        stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
+        return cls(stream_name, instance_name, int(prev_token), int(new_token))
 
     def to_line(self):
-        return " ".join((self.stream_name, self.instance_name, str(self.token)))
+        return " ".join(
+            (
+                self.stream_name,
+                self.instance_name,
+                str(self.prev_token),
+                str(self.new_token),
+            )
+        )
 
 
 class ErrorCommand(_SimpleCommand):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index e92da7b263..95e5502bf2 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -101,8 +101,9 @@ class ReplicationCommandHandler:
         self._streams_to_replicate = []  # type: List[Stream]
 
         for stream in self._streams.values():
-            if stream.NAME == CachesStream.NAME:
-                # All workers can write to the cache invalidation stream.
+            if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
+                # All workers can write to the cache invalidation stream when
+                # using redis.
                 self._streams_to_replicate.append(stream)
                 continue
 
@@ -313,11 +314,14 @@ class ReplicationCommandHandler:
         # We respond with current position of all streams this instance
         # replicates.
         for stream in self.get_streams_to_replicate():
+            # Note that we use the current token as the prev token here (rather
+            # than stream.last_token), as we can't be sure that there have been
+            # no rows written between last token and the current token (since we
+            # might be racing with the replication sending bg process).
+            current_token = stream.current_token(self._instance_name)
             self.send_command(
                 PositionCommand(
-                    stream.NAME,
-                    self._instance_name,
-                    stream.current_token(self._instance_name),
+                    stream.NAME, self._instance_name, current_token, current_token,
                 )
             )
 
@@ -511,16 +515,16 @@ class ReplicationCommandHandler:
         # If the position token matches our current token then we're up to
         # date and there's nothing to do. Otherwise, fetch all updates
         # between then and now.
-        missing_updates = cmd.token != current_token
+        missing_updates = cmd.prev_token != current_token
         while missing_updates:
             logger.info(
                 "Fetching replication rows for '%s' between %i and %i",
                 stream_name,
                 current_token,
-                cmd.token,
+                cmd.new_token,
             )
             (updates, current_token, missing_updates) = await stream.get_updates_since(
-                cmd.instance_name, current_token, cmd.token
+                cmd.instance_name, current_token, cmd.new_token
             )
 
             # TODO: add some tests for this
@@ -536,11 +540,11 @@ class ReplicationCommandHandler:
                     [stream.parse_row(row) for row in rows],
                 )
 
-        logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+        logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)
 
         # We've now caught up to position sent to us, notify handler.
         await self._replication_data_handler.on_position(
-            cmd.stream_name, cmd.instance_name, cmd.token
+            cmd.stream_name, cmd.instance_name, cmd.new_token
         )
 
         self._streams_by_connection.setdefault(conn, set()).add(stream_name)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 687984e7a8..666c13fdb7 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -23,7 +23,9 @@ from prometheus_client import Counter
 from twisted.internet.protocol import Factory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.commands import PositionCommand
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import EventsStream
 from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
@@ -84,6 +86,23 @@ class ReplicationStreamer:
         # Set of streams to replicate.
         self.streams = self.command_handler.get_streams_to_replicate()
 
+        # If we have streams then we must have redis enabled or on master
+        assert (
+            not self.streams
+            or hs.config.redis.redis_enabled
+            or not hs.config.worker.worker_app
+        )
+
+        # If we are replicating an event stream we want to periodically check if
+        # we should send updated POSITIONs. We do this as a looping call rather
+        # explicitly poking when the position advances (without new data to
+        # replicate) to reduce replication traffic (otherwise each writer would
+        # likely send a POSITION for each new event received over replication).
+        #
+        # Note that if the position hasn't advanced then we won't send anything.
+        if any(EventsStream.NAME == s.NAME for s in self.streams):
+            self.clock.looping_call(self.on_notifier_poke, 1000)
+
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
@@ -91,7 +110,7 @@ class ReplicationStreamer:
         This should get called each time new data is available, even if it
         is currently being executed, so that nothing gets missed
         """
-        if not self.command_handler.connected():
+        if not self.command_handler.connected() or not self.streams:
             # Don't bother if nothing is listening. We still need to advance
             # the stream tokens otherwise they'll fall behind forever
             for stream in self.streams:
@@ -136,6 +155,8 @@ class ReplicationStreamer:
                                 self._replication_torture_level / 1000.0
                             )
 
+                        last_token = stream.last_token
+
                         logger.debug(
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
@@ -159,6 +180,30 @@ class ReplicationStreamer:
                             )
                             stream_updates_counter.labels(stream.NAME).inc(len(updates))
 
+                        else:
+                            # The token has advanced but there is no data to
+                            # send, so we send a `POSITION` to inform other
+                            # workers of the updated position.
+                            if stream.NAME == EventsStream.NAME:
+                                # XXX: We only do this for the EventStream as it
+                                # turns out that e.g. account data streams share
+                                # their "current token" with each other, meaning
+                                # that it is *not* safe to send a POSITION.
+                                logger.info(
+                                    "Sending position: %s -> %s",
+                                    stream.NAME,
+                                    current_token,
+                                )
+                                self.command_handler.send_command(
+                                    PositionCommand(
+                                        stream.NAME,
+                                        self._instance_name,
+                                        last_token,
+                                        current_token,
+                                    )
+                                )
+                            continue
+
                         # Some streams return multiple rows with the same stream IDs,
                         # we need to make sure they get sent out in batches. We do
                         # this by setting the current token to all but the last of
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d7e40aaa8b..3d8da48f2d 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -524,6 +524,16 @@ class MultiWriterIdGenerator:
 
         heapq.heappush(self._known_persisted_positions, new_id)
 
+        # If we're a writer and we don't have any active writes we update our
+        # current position to the latest position seen. This allows the instance
+        # to report a recent position when asked, rather than a potentially old
+        # one (if this instance hasn't written anything for a while).
+        our_current_position = self._current_positions.get(self._instance_name)
+        if our_current_position and not self._unfinished_ids:
+            self._current_positions[self._instance_name] = max(
+                our_current_position, new_id
+            )
+
         # We move the current min position up if the minimum current positions
         # of all instances is higher (since by definition all positions less
         # that that have been persisted).
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index ff2d038ad2..4386b6101e 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -126,6 +126,8 @@ class PostgresSequenceGenerator(SequenceGenerator):
         if max_stream_id > last_value:
             logger.warning(
                 "Postgres sequence %s is behind table %s: %d < %d",
+                self._sequence_name,
+                table,
                 last_value,
                 max_stream_id,
             )
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 392b08832b..cc0612cf65 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -199,10 +199,17 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
-        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+        # The first ID gen will notice that it can advance its token to 7 as it
+        # has no in progress writes...
+        self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7})
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
+        # ... but the second ID gen doesn't know that.
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
 
@@ -211,7 +218,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(
-                    first_id_gen.get_positions(), {"first": 3, "second": 7}
+                    first_id_gen.get_positions(), {"first": 7, "second": 7}
                 )
 
         self.get_success(_get_next_async())
@@ -279,7 +286,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self._insert_row_with_id("first", 3)
         self._insert_row_with_id("second", 5)
 
-        id_gen = self._create_id_generator("first", writers=["first", "second"])
+        id_gen = self._create_id_generator("worker", writers=["first", "second"])
 
         self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
 
@@ -319,14 +326,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         id_gen = self._create_id_generator("first", writers=["first", "second"])
 
-        self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+        self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
 
-        self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
         async def _get_next_async():
             async with id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 6)
-                self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+                self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
         self.get_success(_get_next_async())
 
@@ -388,7 +395,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self._insert_row_with_id("second", 5)
 
         # Initial config has two writers
-        id_gen = self._create_id_generator("first", writers=["first", "second"])
+        id_gen = self._create_id_generator("worker", writers=["first", "second"])
         self.assertEqual(id_gen.get_persisted_upto_position(), 3)
         self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
         self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
@@ -568,7 +575,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(_get_next_async2())
 
-        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+        self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2})
         self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)