diff --git a/changelog.d/16473.bugfix b/changelog.d/16473.bugfix
new file mode 100644
index 0000000000..4f4a0380cd
--- /dev/null
+++ b/changelog.d/16473.bugfix
@@ -0,0 +1 @@
+Fix a long-standing, exceedingly rare edge case where the first event persisted by a new event persister worker might not be sent down `/sync`.
diff --git a/docs/development/synapse_architecture/streams.md b/docs/development/synapse_architecture/streams.md
index bee0b8a8c0..67d92acfa1 100644
--- a/docs/development/synapse_architecture/streams.md
+++ b/docs/development/synapse_architecture/streams.md
@@ -51,17 +51,24 @@ will be inserted with that ID.
For any given stream reader (including writers themselves), we may define a per-writer current stream ID:
-> The current stream ID _for a writer W_ is the largest stream ID such that
+> A current stream ID _for a writer W_ is the largest stream ID such that
> all transactions added by W with equal or smaller ID have completed.
Similarly, there is a "linear" notion of current stream ID:
-> The "linear" current stream ID is the largest stream ID such that
+> A "linear" current stream ID is the largest stream ID such that
> all facts (added by any writer) with equal or smaller ID have completed.
Because different stream readers A and B learn about new facts at different times, A and B may disagree about current stream IDs.
Put differently: we should think of stream readers as being independent of each other, proceeding through a stream of facts at different rates.
+The above definition does not give a unique current stream ID: in fact, there
+can be a range of valid current stream IDs. Synapse uses both the minimum and
+maximum IDs for different purposes. Most often the maximum is used, as it is
+generally beneficial for workers to advance their IDs as soon as possible.
+However, the minimum is used in situations where, e.g., another worker is going
+to wait until the stream advances past a given position.
+
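+As an illustrative sketch (the names below are hypothetical, not Synapse's
+actual API), suppose writer `first` has completed all of its writes up to
+stream ID 3, that IDs up to 7 have been allocated and completed by other
+writers, and that `first` has nothing in flight:
+
+```python
+# Hypothetical state, for illustration only.
+positions = {"first": 3, "second": 7}  # last completed write per writer
+max_seen_allocated = 7                 # highest stream ID allocated so far
+
+# Any ID from 3 to 7 (inclusive) is a valid current stream ID for `first`:
+minimum_current = positions["first"]   # 3: its last successful write
+maximum_current = max_seen_allocated   # 7: its next write must be above 7
+```
+
+Synapse's `MultiWriterIdGenerator` exposes both ends of this range:
+`get_current_token_for_writer` returns the maximal known position for a
+writer, while `get_minimal_local_current_token` returns the minimal one for
+the local instance.
+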
**NB.** For both senses of "current", note that if a writer opens a transaction that never completes, the current stream ID will never advance beyond that writer's last written stream ID.
For single-writer streams, the per-writer current ID and the linear current ID are the same.
@@ -114,7 +121,7 @@ Writers need to track:
- track their current position (i.e. its own per-writer stream ID).
- their facts currently awaiting completion.
-At startup,
+At startup,
- the current position of that writer can be found by querying the database (which suggests that facts need to be written to the database atomically, in a transaction); and
- there are no facts awaiting completion.
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 63cf24a14d..7476839db5 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -238,7 +238,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
data[_STREAM_POSITION_KEY] = {
"streams": {
- stream.NAME: stream.current_token(local_instance_name)
+ stream.NAME: stream.minimal_local_current_token()
for stream in streams
},
"instance_name": local_instance_name,
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index c6088a0f99..5c4d228f3d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -33,6 +33,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.storage.util.id_generators import AbstractStreamIdGenerator
logger = logging.getLogger(__name__)
@@ -107,22 +108,10 @@ class Stream:
def __init__(
self,
local_instance_name: str,
- current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
- `current_token_function` and `update_function` are callbacks which
- should be implemented by subclasses.
-
- `current_token_function` takes an instance name, which is a writer to
- the stream, and returns the position in the stream of the writer (as
- viewed from the current process). On the writer process this is where
- the writer has successfully written up to, whereas on other processes
- this is the position which we have received updates up to over
- replication. (Note that most streams have a single writer and so their
- implementations ignore the instance name passed in).
-
`update_function` is called to get updates for this stream between a
pair of stream tokens. See the `UpdateFunction` type definition for more
info.
@@ -133,12 +122,28 @@ class Stream:
update_function: callback to get stream updates, as above
"""
self.local_instance_name = local_instance_name
- self.current_token = current_token_function
self.update_function = update_function
# The token from which we last asked for updates
self.last_token = self.current_token(self.local_instance_name)
+ def current_token(self, instance_name: str) -> Token:
+ """This takes an instance name, which is a writer to
+ the stream, and returns the position in the stream of the writer (as
+ viewed from the current process).
+ """
+ # We can't make this an abstract class as it makes mypy unhappy.
+ raise NotImplementedError()
+
+ def minimal_local_current_token(self) -> Token:
+ """Tries to return a minimal current token for the local instance,
+ i.e. for writers this would be the last successful write.
+
+        If the local instance is not a writer (or has not written yet) then
+        this falls back to returning the normal "current token".
+ """
+ raise NotImplementedError()
+
def discard_updates_and_advance(self) -> None:
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
@@ -190,6 +195,25 @@ class Stream:
return updates, upto_token, limited
+class _StreamFromIdGen(Stream):
+ """Helper class for simple streams that use a stream ID generator"""
+
+ def __init__(
+ self,
+ local_instance_name: str,
+ update_function: UpdateFunction,
+ stream_id_gen: "AbstractStreamIdGenerator",
+ ):
+ self._stream_id_gen = stream_id_gen
+ super().__init__(local_instance_name, update_function)
+
+ def current_token(self, instance_name: str) -> Token:
+ return self._stream_id_gen.get_current_token_for_writer(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
+ return self._stream_id_gen.get_minimal_local_current_token()
+
+
def current_token_without_instance(
current_token: Callable[[], int]
) -> Callable[[str], int]:
@@ -242,17 +266,21 @@ class BackfillStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- self._current_token,
self.store.get_all_new_backfill_event_rows,
)
- def _current_token(self, instance_name: str) -> int:
+ def current_token(self, instance_name: str) -> Token:
# The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it.
return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
+ def minimal_local_current_token(self) -> Token:
+ # The backfill stream over replication operates on *positive* numbers,
+ # which means we need to negate it.
+ return -self.store._backfill_id_gen.get_minimal_local_current_token()
-class PresenceStream(Stream):
+
+class PresenceStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceStreamRow:
user_id: str
@@ -283,9 +311,7 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
- hs.get_instance_name(),
- current_token_without_instance(store.get_current_presence_token),
- update_function,
+ hs.get_instance_name(), update_function, store._presence_id_gen
)
@@ -305,13 +331,18 @@ class PresenceFederationStream(Stream):
ROW_TYPE = PresenceFederationStreamRow
def __init__(self, hs: "HomeServer"):
- federation_queue = hs.get_presence_handler().get_federation_queue()
+ self._federation_queue = hs.get_presence_handler().get_federation_queue()
super().__init__(
hs.get_instance_name(),
- federation_queue.get_current_token,
- federation_queue.get_replication_rows,
+ self._federation_queue.get_replication_rows,
)
+ def current_token(self, instance_name: str) -> Token:
+ return self._federation_queue.get_current_token(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
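+        # The federation queue is in-memory, with no in-flight writes to
+        # wait for, so the current token already serves as the minimal token.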
+ return self._federation_queue.get_current_token(self.local_instance_name)
+
class TypingStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -341,20 +372,25 @@ class TypingStream(Stream):
update_function: Callable[
[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
] = typing_writer_handler.get_all_typing_updates
- current_token_function = typing_writer_handler.get_current_token
+ self.current_token_function = typing_writer_handler.get_current_token
else:
# Query the typing writer process
update_function = make_http_update_function(hs, self.NAME)
- current_token_function = hs.get_typing_handler().get_current_token
+ self.current_token_function = hs.get_typing_handler().get_current_token
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(current_token_function),
update_function,
)
+ def current_token(self, instance_name: str) -> Token:
+ return self.current_token_function()
+
+ def minimal_local_current_token(self) -> Token:
+ return self.current_token_function()
-class ReceiptsStream(Stream):
+
+class ReceiptsStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ReceiptsStreamRow:
room_id: str
@@ -371,12 +407,12 @@ class ReceiptsStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_max_receipt_stream_id),
store.get_all_updated_receipts,
+ store._receipts_id_gen,
)
-class PushRulesStream(Stream):
+class PushRulesStream(_StreamFromIdGen):
"""A user has changed their push rules"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -387,20 +423,16 @@ class PushRulesStream(Stream):
ROW_TYPE = PushRulesStreamRow
def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastores().main
+ store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- self._current_token,
- self.store.get_all_push_rule_updates,
+ store.get_all_push_rule_updates,
+ store._push_rules_stream_id_gen,
)
- def _current_token(self, instance_name: str) -> int:
- push_rules_token = self.store.get_max_push_rules_stream_id()
- return push_rules_token
-
-class PushersStream(Stream):
+class PushersStream(_StreamFromIdGen):
"""A user has added/changed/removed a pusher"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -418,8 +450,8 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_pushers_stream_token),
store.get_all_updated_pushers_rows,
+ store._pushers_id_gen,
)
@@ -447,15 +479,20 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
+ self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- store.get_cache_stream_token_for_writer,
- store.get_all_updated_caches,
+ self.store.get_all_updated_caches,
)
+ def current_token(self, instance_name: str) -> Token:
+ return self.store.get_cache_stream_token_for_writer(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
+ return self.current_token(self.local_instance_name)
+
-class DeviceListsStream(Stream):
+class DeviceListsStream(_StreamFromIdGen):
"""Either a user has updated their devices or a remote server needs to be
told about a device update.
"""
@@ -473,8 +510,8 @@ class DeviceListsStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(self.store.get_device_stream_token),
self._update_function,
+ self.store._device_list_id_gen,
)
async def _update_function(
@@ -525,7 +562,7 @@ class DeviceListsStream(Stream):
return updates, upper_limit_token, devices_limited or signatures_limited
-class ToDeviceStream(Stream):
+class ToDeviceStream(_StreamFromIdGen):
"""New to_device messages for a client"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -539,12 +576,12 @@ class ToDeviceStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_to_device_stream_token),
store.get_all_new_device_messages,
+ store._device_inbox_id_gen,
)
-class AccountDataStream(Stream):
+class AccountDataStream(_StreamFromIdGen):
"""Global or per room account data was changed"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -560,8 +597,8 @@ class AccountDataStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(self.store.get_max_account_data_stream_id),
self._update_function,
+ self.store._account_data_id_gen,
)
async def _update_function(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index da6d948e1b..38823113d8 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
import attr
from synapse.replication.tcp.streams._base import (
- Stream,
StreamRow,
StreamUpdateResult,
Token,
+ _StreamFromIdGen,
)
if TYPE_CHECKING:
@@ -139,7 +139,7 @@ _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
TypeToRow = {Row.TypeId: Row for Row in _EventRows}
-class EventsStream(Stream):
+class EventsStream(_StreamFromIdGen):
"""We received a new event, or an event went from being an outlier to not"""
NAME = "events"
@@ -147,9 +147,7 @@ class EventsStream(Stream):
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
super().__init__(
- hs.get_instance_name(),
- self._store._stream_id_gen.get_current_token_for_writer,
- self._update_function,
+ hs.get_instance_name(), self._update_function, self._store._stream_id_gen
)
async def _update_function(
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 4046bdec69..7f5af5852c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -18,6 +18,7 @@ import attr
from synapse.replication.tcp.streams._base import (
Stream,
+ Token,
current_token_without_instance,
make_http_update_function,
)
@@ -47,7 +48,7 @@ class FederationStream(Stream):
# will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.)
federation_sender = hs.get_federation_sender()
- current_token = current_token_without_instance(
+ self.current_token_func = current_token_without_instance(
federation_sender.get_current_token
)
update_function: Callable[
@@ -57,15 +58,21 @@ class FederationStream(Stream):
elif hs.should_send_federation():
# federation sender: Query master process
update_function = make_http_update_function(hs, self.NAME)
- current_token = self._stub_current_token
+ self.current_token_func = self._stub_current_token
else:
# other worker: stub out the update function (we're not interested in
# any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
- current_token = self._stub_current_token
+ self.current_token_func = self._stub_current_token
- super().__init__(hs.get_instance_name(), current_token, update_function)
+ super().__init__(hs.get_instance_name(), update_function)
+
+ def current_token(self, instance_name: str) -> Token:
+ return self.current_token_func(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
+ return self.current_token(self.local_instance_name)
@staticmethod
def _stub_current_token(instance_name: str) -> int:
diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py
index a8ce5ffd72..ad181d7e93 100644
--- a/synapse/replication/tcp/streams/partial_state.py
+++ b/synapse/replication/tcp/streams/partial_state.py
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING
import attr
-from synapse.replication.tcp.streams import Stream
+from synapse.replication.tcp.streams._base import _StreamFromIdGen
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -27,7 +27,7 @@ class UnPartialStatedRoomStreamRow:
room_id: str
-class UnPartialStatedRoomStream(Stream):
+class UnPartialStatedRoomStream(_StreamFromIdGen):
"""
Stream to notify about rooms becoming un-partial-stated;
that is, when the background sync finishes such that we now have full state for
@@ -41,8 +41,8 @@ class UnPartialStatedRoomStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- store.get_un_partial_stated_rooms_token,
store.get_un_partial_stated_rooms_from_stream,
+ store._un_partial_stated_rooms_stream_id_gen,
)
@@ -56,7 +56,7 @@ class UnPartialStatedEventStreamRow:
rejection_status_changed: bool
-class UnPartialStatedEventStream(Stream):
+class UnPartialStatedEventStream(_StreamFromIdGen):
"""
Stream to notify about events becoming un-partial-stated.
"""
@@ -68,6 +68,6 @@ class UnPartialStatedEventStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- store.get_un_partial_stated_events_token,
store.get_un_partial_stated_events_from_stream,
+ store._un_partial_stated_events_stream_id_gen,
)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d2c874b9a8..9c3eafb562 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -134,6 +134,15 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ def get_minimal_local_current_token(self) -> int:
+ """Tries to return a minimal current token for the local instance,
+ i.e. for writers this would be the last successful write.
+
+        If the local instance is not a writer (or has not written yet) then
+        this falls back to returning the normal "current token".
+ """
+
+ @abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
@@ -312,6 +321,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def get_current_token_for_writer(self, instance_name: str) -> int:
return self.get_current_token()
+ def get_minimal_local_current_token(self) -> int:
+ return self.get_current_token()
+
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with multiple writers.
@@ -408,6 +420,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# The maximum stream ID that we have seen been allocated across any writer.
self._max_seen_allocated_stream_id = 1
+        # The maximum position of the local instance. This can be higher than
+        # the corresponding position in the `current_positions` map when there
+        # are no active writes in progress.
+ self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
@@ -427,6 +444,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._current_positions.values(), default=1
)
+ # For the case where `stream_positions` is not up to date,
+ # `_persisted_upto_position` may be higher.
+ self._max_seen_allocated_stream_id = max(
+ self._max_seen_allocated_stream_id, self._persisted_upto_position
+ )
+
+ # Bump our local maximum position now that we've loaded things from the
+ # DB.
+ self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
if not writers:
# If there have been no explicit writers given then any instance can
# write to the stream. In which case, let's pre-seed our own
@@ -545,6 +572,14 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if instance == self._instance_name:
self._current_positions[instance] = stream_id
+ if self._writers:
+ # If we have explicit writers then make sure that each instance has
+ # a position.
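+            # (A brand-new writer has no rows yet, so it would otherwise be
+            # missing from the map entirely.)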
+ for writer in self._writers:
+ self._current_positions.setdefault(
+ writer, self._persisted_upto_position
+ )
+
cur.close()
def _load_next_id_txn(self, txn: Cursor) -> int:
@@ -688,6 +723,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, new_cur)
+ self._max_position_of_local_instance = max(
+ curr, new_cur, self._max_position_of_local_instance
+ )
self._add_persisted_position(next_id)
@@ -702,10 +740,26 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# persisted up to position. This stops Synapse from doing a full table
# scan when a new writer announces itself over replication.
with self._lock:
- return self._return_factor * self._current_positions.get(
+ if self._instance_name == instance_name:
+ return self._return_factor * self._max_position_of_local_instance
+
+ pos = self._current_positions.get(
instance_name, self._persisted_upto_position
)
+            # We want to return the maximum "current token" that we can for a
+            # writer, as this helps ensure that streams progress as fast as
+            # possible.
+ pos = max(pos, self._persisted_upto_position)
+
+ return self._return_factor * pos
+
+ def get_minimal_local_current_token(self) -> int:
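+        # Example, mirroring the unit tests: with positions
+        # {"first": 3, "second": 7} and no in-flight local writes, instance
+        # "first" reports 3 here (its last completed write), even though
+        # `get_current_token_for_writer("first")` would report 7.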
+ with self._lock:
+ return self._return_factor * self._current_positions.get(
+ self._instance_name, self._persisted_upto_position
+ )
+
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
@@ -774,6 +828,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+ # Advance our local max position.
+ self._max_position_of_local_instance = max(
+ self._max_position_of_local_instance, self._persisted_upto_position
+ )
+
+ if not self._unfinished_ids and not self._in_flight_fetches:
+ # If we don't have anything in flight, it's safe to advance to the
+ # max seen stream ID.
+ self._max_position_of_local_instance = max(
+ self._max_seen_allocated_stream_id, self._max_position_of_local_instance
+ )
+
# We now iterate through the seen positions, discarding those that are
# less than the current min positions, and incrementing the min position
# if its exactly one greater.
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 9174fb0964..fd53b0644c 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -259,8 +259,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
- # The table is empty so we expect an empty map for positions
- self.assertEqual(id_gen.get_positions(), {})
+        # The table is empty, so we expect the position map to contain a
+        # dummy minimum value.
+ self.assertEqual(id_gen.get_positions(), {"master": 1})
def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
@@ -349,15 +350,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- # The first ID gen will notice that it can advance its token to 7 as it
- # has no in progress writes...
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
- # ... but the second ID gen doesn't know that.
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position
@@ -398,6 +396,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.advance("first", 8)
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+ def test_multi_instance_empty_row(self) -> None:
+ """Test that reads and writes from multiple processes are handled
+ correctly, when one of the writers starts without any rows.
+ """
+ # Insert some rows for two out of three of the ID gens.
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator(
+ "first", writers=["first", "second", "third"]
+ )
+ second_id_gen = self._create_id_generator(
+ "second", writers=["first", "second", "third"]
+ )
+ third_id_gen = self._create_id_generator(
+ "third", writers=["first", "second", "third"]
+ )
+
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
+
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ async def _get_next_async() -> None:
+ async with third_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(
+ third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(
+ third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
+ )
+
def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly."""
@@ -600,6 +648,70 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
with self.assertRaises(IncorrectDatabaseSetup):
self._create_id_generator("first")
+ def test_minimal_local_token(self) -> None:
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
+
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(second_id_gen.get_minimal_local_current_token(), 7)
+
+ def test_current_token_gap(self) -> None:
+ """Test that getting the current token for a writer returns the maximal
+ token when there are no writes.
+ """
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator(
+ "first", writers=["first", "second", "third"]
+ )
+ second_id_gen = self._create_id_generator(
+ "second", writers=["first", "second", "third"]
+ )
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
+ # Check that the first ID gen advancing causes the second ID gen to
+ # advance (as the second ID gen has nothing in flight).
+
+ async def _get_next_async() -> None:
+ async with first_id_gen.get_next_mult(2):
+ pass
+
+ self.get_success(_get_next_async())
+ second_id_gen.advance("first", 9)
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 9)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
+ # Check that the first ID gen advancing doesn't advance the second ID
+ # gen when the second ID gen has stuff in flight.
+ self.get_success(_get_next_async())
+
+ ctxmgr = second_id_gen.get_next()
+ self.get_success(ctxmgr.__aenter__())
+
+ second_id_gen.advance("first", 11)
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
+ self.get_success(ctxmgr.__aexit__(None, None, None))
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 12)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
@@ -712,8 +824,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async())
- self.assertEqual(id_gen_1.get_positions(), {"first": -1})
- self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -1})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -1})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
@@ -822,11 +934,11 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
- self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)
|