summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16473.bugfix1
-rw-r--r--docs/development/synapse_architecture/streams.md13
-rw-r--r--synapse/replication/http/_base.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py129
-rw-r--r--synapse/replication/tcp/streams/events.py8
-rw-r--r--synapse/replication/tcp/streams/federation.py15
-rw-r--r--synapse/replication/tcp/streams/partial_state.py10
-rw-r--r--synapse/storage/util/id_generators.py68
-rw-r--r--tests/storage/test_id_generators.py136
9 files changed, 305 insertions, 77 deletions
diff --git a/changelog.d/16473.bugfix b/changelog.d/16473.bugfix
new file mode 100644
index 0000000000..4f4a0380cd
--- /dev/null
+++ b/changelog.d/16473.bugfix
@@ -0,0 +1 @@
+Fix a long-standing, exceedingly rare edge case where the first event persisted by a new event persister worker might not be sent down `/sync`.
diff --git a/docs/development/synapse_architecture/streams.md b/docs/development/synapse_architecture/streams.md
index bee0b8a8c0..67d92acfa1 100644
--- a/docs/development/synapse_architecture/streams.md
+++ b/docs/development/synapse_architecture/streams.md
@@ -51,17 +51,24 @@ will be inserted with that ID.
 
 For any given stream reader (including writers themselves), we may define a per-writer current stream ID:
 
-> The current stream ID _for a writer W_ is the largest stream ID such that
+> A current stream ID _for a writer W_ is the largest stream ID such that
 > all transactions added by W with equal or smaller ID have completed.
 
 Similarly, there is a "linear" notion of current stream ID:
 
-> The "linear" current stream ID is the largest stream ID such that
+> A "linear" current stream ID is the largest stream ID such that
 > all facts (added by any writer) with equal or smaller ID have completed.
 
 Because different stream readers A and B learn about new facts at different times, A and B may disagree about current stream IDs.
 Put differently: we should think of stream readers as being independent of each other, proceeding through a stream of facts at different rates.
 
+The above definition does not give a unique current stream ID, in fact there can
+be a range of current stream IDs. Synapse uses both the minimum and maximum IDs
+for different purposes. Most often the maximum is used, as its generally
+beneficial for workers to advance their IDs as soon as possible. However, the
+minimum is used in situations where e.g. another worker is going to wait until
+the stream advances past a position.
+
 **NB.** For both senses of "current", that if a writer opens a transaction that never completes, the current stream ID will never advance beyond that writer's last written stream ID.
 
 For single-writer streams, the per-writer current ID and the linear current ID are the same.
@@ -114,7 +121,7 @@ Writers need to track:
  - track their current position (i.e. its own per-writer stream ID).
  - their facts currently awaiting completion.
 
-At startup, 
+At startup,
  - the current position of that writer can be found by querying the database (which suggests that facts need to be written to the database atomically, in a transaction); and
  - there are no facts awaiting completion.
 
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 63cf24a14d..7476839db5 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -238,7 +238,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
                     data[_STREAM_POSITION_KEY] = {
                         "streams": {
-                            stream.NAME: stream.current_token(local_instance_name)
+                            stream.NAME: stream.minimal_local_current_token()
                             for stream in streams
                         },
                         "instance_name": local_instance_name,
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index c6088a0f99..5c4d228f3d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -33,6 +33,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 
 logger = logging.getLogger(__name__)
 
@@ -107,22 +108,10 @@ class Stream:
     def __init__(
         self,
         local_instance_name: str,
-        current_token_function: Callable[[str], Token],
         update_function: UpdateFunction,
     ):
         """Instantiate a Stream
 
-        `current_token_function` and `update_function` are callbacks which
-        should be implemented by subclasses.
-
-        `current_token_function` takes an instance name, which is a writer to
-        the stream, and returns the position in the stream of the writer (as
-        viewed from the current process). On the writer process this is where
-        the writer has successfully written up to, whereas on other processes
-        this is the position which we have received updates up to over
-        replication. (Note that most streams have a single writer and so their
-        implementations ignore the instance name passed in).
-
         `update_function` is called to get updates for this stream between a
         pair of stream tokens. See the `UpdateFunction` type definition for more
         info.
@@ -133,12 +122,28 @@ class Stream:
             update_function: callback go get stream updates, as above
         """
         self.local_instance_name = local_instance_name
-        self.current_token = current_token_function
         self.update_function = update_function
 
         # The token from which we last asked for updates
         self.last_token = self.current_token(self.local_instance_name)
 
+    def current_token(self, instance_name: str) -> Token:
+        """This takes an instance name, which is a writer to
+        the stream, and returns the position in the stream of the writer (as
+        viewed from the current process).
+        """
+        # We can't make this an abstract class as it makes mypy unhappy.
+        raise NotImplementedError()
+
+    def minimal_local_current_token(self) -> Token:
+        """Tries to return a minimal current token for the local instance,
+        i.e. for writers this would be the last successful write.
+
+        If local instance is not a writer (or has written yet) then falls back
+        to returning the normal "current token".
+        """
+        raise NotImplementedError()
+
     def discard_updates_and_advance(self) -> None:
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
@@ -190,6 +195,25 @@ class Stream:
         return updates, upto_token, limited
 
 
+class _StreamFromIdGen(Stream):
+    """Helper class for simple streams that use a stream ID generator"""
+
+    def __init__(
+        self,
+        local_instance_name: str,
+        update_function: UpdateFunction,
+        stream_id_gen: "AbstractStreamIdGenerator",
+    ):
+        self._stream_id_gen = stream_id_gen
+        super().__init__(local_instance_name, update_function)
+
+    def current_token(self, instance_name: str) -> Token:
+        return self._stream_id_gen.get_current_token_for_writer(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self._stream_id_gen.get_minimal_local_current_token()
+
+
 def current_token_without_instance(
     current_token: Callable[[], int]
 ) -> Callable[[str], int]:
@@ -242,17 +266,21 @@ class BackfillStream(Stream):
         self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            self._current_token,
             self.store.get_all_new_backfill_event_rows,
         )
 
-    def _current_token(self, instance_name: str) -> int:
+    def current_token(self, instance_name: str) -> Token:
         # The backfill stream over replication operates on *positive* numbers,
         # which means we need to negate it.
         return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
 
+    def minimal_local_current_token(self) -> Token:
+        # The backfill stream over replication operates on *positive* numbers,
+        # which means we need to negate it.
+        return -self.store._backfill_id_gen.get_minimal_local_current_token()
 
-class PresenceStream(Stream):
+
+class PresenceStream(_StreamFromIdGen):
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class PresenceStreamRow:
         user_id: str
@@ -283,9 +311,7 @@ class PresenceStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(),
-            current_token_without_instance(store.get_current_presence_token),
-            update_function,
+            hs.get_instance_name(), update_function, store._presence_id_gen
         )
 
 
@@ -305,13 +331,18 @@ class PresenceFederationStream(Stream):
     ROW_TYPE = PresenceFederationStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        federation_queue = hs.get_presence_handler().get_federation_queue()
+        self._federation_queue = hs.get_presence_handler().get_federation_queue()
         super().__init__(
             hs.get_instance_name(),
-            federation_queue.get_current_token,
-            federation_queue.get_replication_rows,
+            self._federation_queue.get_replication_rows,
         )
 
+    def current_token(self, instance_name: str) -> Token:
+        return self._federation_queue.get_current_token(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self._federation_queue.get_current_token(self.local_instance_name)
+
 
 class TypingStream(Stream):
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -341,20 +372,25 @@ class TypingStream(Stream):
             update_function: Callable[
                 [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
             ] = typing_writer_handler.get_all_typing_updates
-            current_token_function = typing_writer_handler.get_current_token
+            self.current_token_function = typing_writer_handler.get_current_token
         else:
             # Query the typing writer process
             update_function = make_http_update_function(hs, self.NAME)
-            current_token_function = hs.get_typing_handler().get_current_token
+            self.current_token_function = hs.get_typing_handler().get_current_token
 
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(current_token_function),
             update_function,
         )
 
+    def current_token(self, instance_name: str) -> Token:
+        return self.current_token_function()
+
+    def minimal_local_current_token(self) -> Token:
+        return self.current_token_function()
 
-class ReceiptsStream(Stream):
+
+class ReceiptsStream(_StreamFromIdGen):
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class ReceiptsStreamRow:
         room_id: str
@@ -371,12 +407,12 @@ class ReceiptsStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_max_receipt_stream_id),
             store.get_all_updated_receipts,
+            store._receipts_id_gen,
         )
 
 
-class PushRulesStream(Stream):
+class PushRulesStream(_StreamFromIdGen):
     """A user has changed their push rules"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -387,20 +423,16 @@ class PushRulesStream(Stream):
     ROW_TYPE = PushRulesStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastores().main
+        store = hs.get_datastores().main
 
         super().__init__(
             hs.get_instance_name(),
-            self._current_token,
-            self.store.get_all_push_rule_updates,
+            store.get_all_push_rule_updates,
+            store._push_rules_stream_id_gen,
         )
 
-    def _current_token(self, instance_name: str) -> int:
-        push_rules_token = self.store.get_max_push_rules_stream_id()
-        return push_rules_token
-
 
-class PushersStream(Stream):
+class PushersStream(_StreamFromIdGen):
     """A user has added/changed/removed a pusher"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -418,8 +450,8 @@ class PushersStream(Stream):
 
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_pushers_stream_token),
             store.get_all_updated_pushers_rows,
+            store._pushers_id_gen,
         )
 
 
@@ -447,15 +479,20 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastores().main
+        self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token_for_writer,
-            store.get_all_updated_caches,
+            self.store.get_all_updated_caches,
         )
 
+    def current_token(self, instance_name: str) -> Token:
+        return self.store.get_cache_stream_token_for_writer(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self.current_token(self.local_instance_name)
+
 
-class DeviceListsStream(Stream):
+class DeviceListsStream(_StreamFromIdGen):
     """Either a user has updated their devices or a remote server needs to be
     told about a device update.
     """
@@ -473,8 +510,8 @@ class DeviceListsStream(Stream):
         self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(self.store.get_device_stream_token),
             self._update_function,
+            self.store._device_list_id_gen,
         )
 
     async def _update_function(
@@ -525,7 +562,7 @@ class DeviceListsStream(Stream):
         return updates, upper_limit_token, devices_limited or signatures_limited
 
 
-class ToDeviceStream(Stream):
+class ToDeviceStream(_StreamFromIdGen):
     """New to_device messages for a client"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -539,12 +576,12 @@ class ToDeviceStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_to_device_stream_token),
             store.get_all_new_device_messages,
+            store._device_inbox_id_gen,
         )
 
 
-class AccountDataStream(Stream):
+class AccountDataStream(_StreamFromIdGen):
     """Global or per room account data was changed"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -560,8 +597,8 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(self.store.get_max_account_data_stream_id),
             self._update_function,
+            self.store._account_data_id_gen,
         )
 
     async def _update_function(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index da6d948e1b..38823113d8 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
 import attr
 
 from synapse.replication.tcp.streams._base import (
-    Stream,
     StreamRow,
     StreamUpdateResult,
     Token,
+    _StreamFromIdGen,
 )
 
 if TYPE_CHECKING:
@@ -139,7 +139,7 @@ _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
 TypeToRow = {Row.TypeId: Row for Row in _EventRows}
 
 
-class EventsStream(Stream):
+class EventsStream(_StreamFromIdGen):
     """We received a new event, or an event went from being an outlier to not"""
 
     NAME = "events"
@@ -147,9 +147,7 @@ class EventsStream(Stream):
     def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastores().main
         super().__init__(
-            hs.get_instance_name(),
-            self._store._stream_id_gen.get_current_token_for_writer,
-            self._update_function,
+            hs.get_instance_name(), self._update_function, self._store._stream_id_gen
         )
 
     async def _update_function(
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 4046bdec69..7f5af5852c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -18,6 +18,7 @@ import attr
 
 from synapse.replication.tcp.streams._base import (
     Stream,
+    Token,
     current_token_without_instance,
     make_http_update_function,
 )
@@ -47,7 +48,7 @@ class FederationStream(Stream):
             # will be a real FederationSender, which has stubs for current_token and
             # get_replication_rows.)
             federation_sender = hs.get_federation_sender()
-            current_token = current_token_without_instance(
+            self.current_token_func = current_token_without_instance(
                 federation_sender.get_current_token
             )
             update_function: Callable[
@@ -57,15 +58,21 @@ class FederationStream(Stream):
         elif hs.should_send_federation():
             # federation sender: Query master process
             update_function = make_http_update_function(hs, self.NAME)
-            current_token = self._stub_current_token
+            self.current_token_func = self._stub_current_token
 
         else:
             # other worker: stub out the update function (we're not interested in
             # any updates so when we get a POSITION we do nothing)
             update_function = self._stub_update_function
-            current_token = self._stub_current_token
+            self.current_token_func = self._stub_current_token
 
-        super().__init__(hs.get_instance_name(), current_token, update_function)
+        super().__init__(hs.get_instance_name(), update_function)
+
+    def current_token(self, instance_name: str) -> Token:
+        return self.current_token_func(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self.current_token(self.local_instance_name)
 
     @staticmethod
     def _stub_current_token(instance_name: str) -> int:
diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py
index a8ce5ffd72..ad181d7e93 100644
--- a/synapse/replication/tcp/streams/partial_state.py
+++ b/synapse/replication/tcp/streams/partial_state.py
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING
 
 import attr
 
-from synapse.replication.tcp.streams import Stream
+from synapse.replication.tcp.streams._base import _StreamFromIdGen
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -27,7 +27,7 @@ class UnPartialStatedRoomStreamRow:
     room_id: str
 
 
-class UnPartialStatedRoomStream(Stream):
+class UnPartialStatedRoomStream(_StreamFromIdGen):
     """
     Stream to notify about rooms becoming un-partial-stated;
     that is, when the background sync finishes such that we now have full state for
@@ -41,8 +41,8 @@ class UnPartialStatedRoomStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            store.get_un_partial_stated_rooms_token,
             store.get_un_partial_stated_rooms_from_stream,
+            store._un_partial_stated_rooms_stream_id_gen,
         )
 
 
@@ -56,7 +56,7 @@ class UnPartialStatedEventStreamRow:
     rejection_status_changed: bool
 
 
-class UnPartialStatedEventStream(Stream):
+class UnPartialStatedEventStream(_StreamFromIdGen):
     """
     Stream to notify about events becoming un-partial-stated.
     """
@@ -68,6 +68,6 @@ class UnPartialStatedEventStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            store.get_un_partial_stated_events_token,
             store.get_un_partial_stated_events_from_stream,
+            store._un_partial_stated_events_stream_id_gen,
         )
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d2c874b9a8..9c3eafb562 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -134,6 +134,15 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
+    def get_minimal_local_current_token(self) -> int:
+        """Tries to return a minimal current token for the local instance,
+        i.e. for writers this would be the last successful write.
+
+        If local instance is not a writer (or has written yet) then falls back
+        to returning the normal "current token".
+        """
+
+    @abc.abstractmethod
     def get_next(self) -> AsyncContextManager[int]:
         """
         Usage:
@@ -312,6 +321,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
     def get_current_token_for_writer(self, instance_name: str) -> int:
         return self.get_current_token()
 
+    def get_minimal_local_current_token(self) -> int:
+        return self.get_current_token()
+
 
 class MultiWriterIdGenerator(AbstractStreamIdGenerator):
     """Generates and tracks stream IDs for a stream with multiple writers.
@@ -408,6 +420,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         # The maximum stream ID that we have seen been allocated across any writer.
         self._max_seen_allocated_stream_id = 1
 
+        # The maximum position of the local instance. This can be higher than
+        # the corresponding position in `current_positions` table when there are
+        # no active writes in progress.
+        self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
         # We check that the table and sequence haven't diverged.
@@ -427,6 +444,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             self._current_positions.values(), default=1
         )
 
+        # For the case where `stream_positions` is not up to date,
+        # `_persisted_upto_position` may be higher.
+        self._max_seen_allocated_stream_id = max(
+            self._max_seen_allocated_stream_id, self._persisted_upto_position
+        )
+
+        # Bump our local maximum position now that we've loaded things from the
+        # DB.
+        self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
         if not writers:
             # If there have been no explicit writers given then any instance can
             # write to the stream. In which case, let's pre-seed our own
@@ -545,6 +572,14 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                     if instance == self._instance_name:
                         self._current_positions[instance] = stream_id
 
+        if self._writers:
+            # If we have explicit writers then make sure that each instance has
+            # a position.
+            for writer in self._writers:
+                self._current_positions.setdefault(
+                    writer, self._persisted_upto_position
+                )
+
         cur.close()
 
     def _load_next_id_txn(self, txn: Cursor) -> int:
@@ -688,6 +723,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             if new_cur:
                 curr = self._current_positions.get(self._instance_name, 0)
                 self._current_positions[self._instance_name] = max(curr, new_cur)
+                self._max_position_of_local_instance = max(
+                    curr, new_cur, self._max_position_of_local_instance
+                )
 
             self._add_persisted_position(next_id)
 
@@ -702,10 +740,26 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         # persisted up to position. This stops Synapse from doing a full table
         # scan when a new writer announces itself over replication.
         with self._lock:
-            return self._return_factor * self._current_positions.get(
+            if self._instance_name == instance_name:
+                return self._return_factor * self._max_position_of_local_instance
+
+            pos = self._current_positions.get(
                 instance_name, self._persisted_upto_position
             )
 
+            # We want to return the maximum "current token" that we can for a
+            # writer, this helps ensure that streams progress as fast as
+            # possible.
+            pos = max(pos, self._persisted_upto_position)
+
+            return self._return_factor * pos
+
+    def get_minimal_local_current_token(self) -> int:
+        with self._lock:
+            return self._return_factor * self._current_positions.get(
+                self._instance_name, self._persisted_upto_position
+            )
+
     def get_positions(self) -> Dict[str, int]:
         """Get a copy of the current positon map.
 
@@ -774,6 +828,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
 
         self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
 
+        # Advance our local max position.
+        self._max_position_of_local_instance = max(
+            self._max_position_of_local_instance, self._persisted_upto_position
+        )
+
+        if not self._unfinished_ids and not self._in_flight_fetches:
+            # If we don't have anything in flight, it's safe to advance to the
+            # max seen stream ID.
+            self._max_position_of_local_instance = max(
+                self._max_seen_allocated_stream_id, self._max_position_of_local_instance
+            )
+
         # We now iterate through the seen positions, discarding those that are
         # less than the current min positions, and incrementing the min position
         # if its exactly one greater.
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 9174fb0964..fd53b0644c 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -259,8 +259,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         id_gen = self._create_id_generator()
 
-        # The table is empty so we expect an empty map for positions
-        self.assertEqual(id_gen.get_positions(), {})
+        # The table is empty so we expect the map for positions to have a dummy
+        # minimum value.
+        self.assertEqual(id_gen.get_positions(), {"master": 1})
 
     def test_single_instance(self) -> None:
         """Test that reads and writes from a single process are handled
@@ -349,15 +350,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
-        # The first ID gen will notice that it can advance its token to 7 as it
-        # has no in progress writes...
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
-        # ... but the second ID gen doesn't know that.
         self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
 
         # Try allocating a new ID gen and check that we only see position
@@ -398,6 +396,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen.advance("first", 8)
         self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
 
+    def test_multi_instance_empty_row(self) -> None:
+        """Test that reads and writes from multiple processes are handled
+        correctly, when one of the writers starts without any rows.
+        """
+        # Insert some rows for two out of three of the ID gens.
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator(
+            "first", writers=["first", "second", "third"]
+        )
+        second_id_gen = self._create_id_generator(
+            "second", writers=["first", "second", "third"]
+        )
+        third_id_gen = self._create_id_generator(
+            "third", writers=["first", "second", "third"]
+        )
+
+        self.assertEqual(
+            first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+        )
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
+
+        self.assertEqual(
+            second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+        )
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        async def _get_next_async() -> None:
+            async with third_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 8)
+
+                self.assertEqual(
+                    third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+                )
+                self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(
+            third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
+        )
+
     def test_get_next_txn(self) -> None:
         """Test that the `get_next_txn` function works correctly."""
 
@@ -600,6 +648,70 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         with self.assertRaises(IncorrectDatabaseSetup):
             self._create_id_generator("first")
 
+    def test_minimal_local_token(self) -> None:
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+        second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
+
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_minimal_local_current_token(), 7)
+
+    def test_current_token_gap(self) -> None:
+        """Test that getting the current token for a writer returns the maximal
+        token when there are no writes.
+        """
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator(
+            "first", writers=["first", "second", "third"]
+        )
+        second_id_gen = self._create_id_generator(
+            "second", writers=["first", "second", "third"]
+        )
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
+        # Check that the first ID gen advancing causes the second ID gen to
+        # advance (as the second ID gen has nothing in flight).
+
+        async def _get_next_async() -> None:
+            async with first_id_gen.get_next_mult(2):
+                pass
+
+        self.get_success(_get_next_async())
+        second_id_gen.advance("first", 9)
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 9)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
+        # Check that the first ID gen advancing doesn't advance the second ID
+        # gen when the second ID gen has stuff in flight.
+        self.get_success(_get_next_async())
+
+        ctxmgr = second_id_gen.get_next()
+        self.get_success(ctxmgr.__aenter__())
+
+        second_id_gen.advance("first", 11)
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
+        self.get_success(ctxmgr.__aexit__(None, None, None))
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 12)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
 
 class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
     """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
@@ -712,8 +824,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(_get_next_async())
 
-        self.assertEqual(id_gen_1.get_positions(), {"first": -1})
-        self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -1})
+        self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -1})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
 
@@ -822,11 +934,11 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
-        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
 
         self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)