diff --git a/changelog.d/8499.misc b/changelog.d/8499.misc
new file mode 100644
index 0000000000..237cb3b311
--- /dev/null
+++ b/changelog.d/8499.misc
@@ -0,0 +1 @@
+Allow events to be sent to clients sooner when using sharded event persisters.
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index db318baa9d..ad145439b4 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -15,7 +15,7 @@ example flow would be (where '>' indicates master to worker and
> SERVER example.com
< REPLICATE
- > POSITION events master 53
+ > POSITION events master 53 53
> RDATA events master 54 ["$foo1:bar.com", ...]
> RDATA events master 55 ["$foo4:bar.com", ...]
@@ -138,9 +138,9 @@ the wire:
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE
- > POSITION events master 1
- > POSITION backfill master 1
- > POSITION caches master 1
+ > POSITION events master 1 1
+ > POSITION backfill master 1 1
+ > POSITION caches master 1 1
> RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
> RDATA events master 14 ["$149019767112vOHxz:localhost:8823",
"!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]
@@ -185,6 +185,11 @@ client (C):
updates via HTTP API, rather than via the DB, then processes should make the
request to the appropriate process.
+ Two positions are included, the "new" position and the last position sent respectively.
+ This allows servers to tell instances that the positions have advanced but no
+ data has been written, without clients needlessly checking to see if they
+ have missed any updates.
+
#### ERROR (S, C)
There was an error
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e165429cad..e27ee216f0 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -191,6 +191,10 @@ class ReplicationDataHandler:
async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, instance_name, token, [])
+ # We poke the generic "replication" notifier to wake anything up that
+ # may be streaming.
+ self.notifier.notify_replication()
+
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 8cd47770c1..ac532ed588 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -141,15 +141,23 @@ class RdataCommand(Command):
class PositionCommand(Command):
- """Sent by the server to tell the client the stream position without
- needing to send an RDATA.
+ """Sent by an instance to tell others the stream position without needing to
+ send an RDATA.
+
+ Two tokens are sent, the new position and the last position sent by the
+ instance (in an RDATA or other POSITION). The tokens are chosen so that *no*
+ rows were written by the instance between the `prev_token` and `new_token`.
+ (If an instance hasn't sent a position before then the new position can be
+ used for both.)
Format::
- POSITION <stream_name> <instance_name> <token>
+ POSITION <stream_name> <instance_name> <prev_token> <new_token>
- On receipt of a POSITION command clients should check if they have missed
- any updates, and if so then fetch them out of band.
+ On receipt of a POSITION command instances should check if they have missed
+ any updates, and if so then fetch them out of band. Instances can check this
+ by comparing their view of the current token for the sending instance with
+ the included `prev_token`.
The `<instance_name>` is the process that sent the command and is the source
of the stream.
@@ -157,18 +165,26 @@ class PositionCommand(Command):
NAME = "POSITION"
- def __init__(self, stream_name, instance_name, token):
+ def __init__(self, stream_name, instance_name, prev_token, new_token):
self.stream_name = stream_name
self.instance_name = instance_name
- self.token = token
+ self.prev_token = prev_token
+ self.new_token = new_token
@classmethod
def from_line(cls, line):
- stream_name, instance_name, token = line.split(" ", 2)
- return cls(stream_name, instance_name, int(token))
+ stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
+ return cls(stream_name, instance_name, int(prev_token), int(new_token))
def to_line(self):
- return " ".join((self.stream_name, self.instance_name, str(self.token)))
+ return " ".join(
+ (
+ self.stream_name,
+ self.instance_name,
+ str(self.prev_token),
+ str(self.new_token),
+ )
+ )
class ErrorCommand(_SimpleCommand):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index e92da7b263..95e5502bf2 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -101,8 +101,9 @@ class ReplicationCommandHandler:
self._streams_to_replicate = [] # type: List[Stream]
for stream in self._streams.values():
- if stream.NAME == CachesStream.NAME:
- # All workers can write to the cache invalidation stream.
+ if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
+ # All workers can write to the cache invalidation stream when
+ # using redis.
self._streams_to_replicate.append(stream)
continue
@@ -313,11 +314,14 @@ class ReplicationCommandHandler:
# We respond with current position of all streams this instance
# replicates.
for stream in self.get_streams_to_replicate():
+ # Note that we use the current token as the prev token here (rather
+ # than stream.last_token), as we can't be sure that there have been
+ # no rows written between last token and the current token (since we
+ # might be racing with the replication sending bg process).
+ current_token = stream.current_token(self._instance_name)
self.send_command(
PositionCommand(
- stream.NAME,
- self._instance_name,
- stream.current_token(self._instance_name),
+ stream.NAME, self._instance_name, current_token, current_token,
)
)
@@ -511,16 +515,16 @@ class ReplicationCommandHandler:
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
- missing_updates = cmd.token != current_token
+ missing_updates = cmd.prev_token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
- cmd.token,
+ cmd.new_token,
)
(updates, current_token, missing_updates) = await stream.get_updates_since(
- cmd.instance_name, current_token, cmd.token
+ cmd.instance_name, current_token, cmd.new_token
)
# TODO: add some tests for this
@@ -536,11 +540,11 @@ class ReplicationCommandHandler:
[stream.parse_row(row) for row in rows],
)
- logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+ logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
- cmd.stream_name, cmd.instance_name, cmd.token
+ cmd.stream_name, cmd.instance_name, cmd.new_token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 687984e7a8..666c13fdb7 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -23,7 +23,9 @@ from prometheus_client import Counter
from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.commands import PositionCommand
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import EventsStream
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -84,6 +86,23 @@ class ReplicationStreamer:
# Set of streams to replicate.
self.streams = self.command_handler.get_streams_to_replicate()
+ # If we have streams then we must have redis enabled or on master
+ assert (
+ not self.streams
+ or hs.config.redis.redis_enabled
+ or not hs.config.worker.worker_app
+ )
+
+ # If we are replicating an event stream we want to periodically check if
+ # we should send updated POSITIONs. We do this as a looping call rather
+ # explicitly poking when the position advances (without new data to
+ # replicate) to reduce replication traffic (otherwise each writer would
+ # likely send a POSITION for each new event received over replication).
+ #
+ # Note that if the position hasn't advanced then we won't send anything.
+ if any(EventsStream.NAME == s.NAME for s in self.streams):
+ self.clock.looping_call(self.on_notifier_poke, 1000)
+
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -91,7 +110,7 @@ class ReplicationStreamer:
This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed
"""
- if not self.command_handler.connected():
+ if not self.command_handler.connected() or not self.streams:
# Don't bother if nothing is listening. We still need to advance
# the stream tokens otherwise they'll fall behind forever
for stream in self.streams:
@@ -136,6 +155,8 @@ class ReplicationStreamer:
self._replication_torture_level / 1000.0
)
+ last_token = stream.last_token
+
logger.debug(
"Getting stream: %s: %s -> %s",
stream.NAME,
@@ -159,6 +180,30 @@ class ReplicationStreamer:
)
stream_updates_counter.labels(stream.NAME).inc(len(updates))
+ else:
+ # The token has advanced but there is no data to
+ # send, so we send a `POSITION` to inform other
+ # workers of the updated position.
+ if stream.NAME == EventsStream.NAME:
+ # XXX: We only do this for the EventStream as it
+ # turns out that e.g. account data streams share
+ # their "current token" with each other, meaning
+ # that it is *not* safe to send a POSITION.
+ logger.info(
+ "Sending position: %s -> %s",
+ stream.NAME,
+ current_token,
+ )
+ self.command_handler.send_command(
+ PositionCommand(
+ stream.NAME,
+ self._instance_name,
+ last_token,
+ current_token,
+ )
+ )
+ continue
+
# Some streams return multiple rows with the same stream IDs,
# we need to make sure they get sent out in batches. We do
# this by setting the current token to all but the last of
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d7e40aaa8b..3d8da48f2d 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -524,6 +524,16 @@ class MultiWriterIdGenerator:
heapq.heappush(self._known_persisted_positions, new_id)
+ # If we're a writer and we don't have any active writes we update our
+ # current position to the latest position seen. This allows the instance
+ # to report a recent position when asked, rather than a potentially old
+ # one (if this instance hasn't written anything for a while).
+ our_current_position = self._current_positions.get(self._instance_name)
+ if our_current_position and not self._unfinished_ids:
+ self._current_positions[self._instance_name] = max(
+ our_current_position, new_id
+ )
+
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index ff2d038ad2..4386b6101e 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -126,6 +126,8 @@ class PostgresSequenceGenerator(SequenceGenerator):
if max_stream_id > last_value:
logger.warning(
"Postgres sequence %s is behind table %s: %d < %d",
+ self._sequence_name,
+ table,
last_value,
max_stream_id,
)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 392b08832b..cc0612cf65 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -199,10 +199,17 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ # The first ID gen will notice that it can advance its token to 7 as it
+ # has no in progress writes...
+ self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+ # ... but the second ID gen doesn't know that.
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -211,7 +218,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(
- first_id_gen.get_positions(), {"first": 3, "second": 7}
+ first_id_gen.get_positions(), {"first": 7, "second": 7}
)
self.get_success(_get_next_async())
@@ -279,7 +286,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first", writers=["first", "second"])
+ id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
@@ -319,14 +326,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator("first", writers=["first", "second"])
- self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+ self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 5)
async def _get_next_async():
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 5)
self.get_success(_get_next_async())
@@ -388,7 +395,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("second", 5)
# Initial config has two writers
- id_gen = self._create_id_generator("first", writers=["first", "second"])
+ id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
@@ -568,7 +575,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async2())
- self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
|