diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/replication/http/_base.py | 2 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 129 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/events.py | 8 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/federation.py | 15 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/partial_state.py | 10 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 68 |
6 files changed, 170 insertions, 62 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 63cf24a14d..7476839db5 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -238,7 +238,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): data[_STREAM_POSITION_KEY] = { "streams": { - stream.NAME: stream.current_token(local_instance_name) + stream.NAME: stream.minimal_local_current_token() for stream in streams }, "instance_name": local_instance_name, diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index c6088a0f99..5c4d228f3d 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -33,6 +33,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.util.id_generators import AbstractStreamIdGenerator logger = logging.getLogger(__name__) @@ -107,22 +108,10 @@ class Stream: def __init__( self, local_instance_name: str, - current_token_function: Callable[[str], Token], update_function: UpdateFunction, ): """Instantiate a Stream - `current_token_function` and `update_function` are callbacks which - should be implemented by subclasses. - - `current_token_function` takes an instance name, which is a writer to - the stream, and returns the position in the stream of the writer (as - viewed from the current process). On the writer process this is where - the writer has successfully written up to, whereas on other processes - this is the position which we have received updates up to over - replication. (Note that most streams have a single writer and so their - implementations ignore the instance name passed in). - `update_function` is called to get updates for this stream between a pair of stream tokens. See the `UpdateFunction` type definition for more info. @@ -133,12 +122,28 @@ class Stream: update_function: callback go get stream updates, as above """ self.local_instance_name = local_instance_name - self.current_token = current_token_function self.update_function = update_function # The token from which we last asked for updates self.last_token = self.current_token(self.local_instance_name) + def current_token(self, instance_name: str) -> Token: + """This takes an instance name, which is a writer to + the stream, and returns the position in the stream of the writer (as + viewed from the current process). + """ + # We can't make this an abstract class as it makes mypy unhappy. + raise NotImplementedError() + + def minimal_local_current_token(self) -> Token: + """Tries to return a minimal current token for the local instance, + i.e. for writers this would be the last successful write. + + If local instance is not a writer (or has written yet) then falls back + to returning the normal "current token". + """ + raise NotImplementedError() + def discard_updates_and_advance(self) -> None: """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. @@ -190,6 +195,25 @@ class Stream: return updates, upto_token, limited +class _StreamFromIdGen(Stream): + """Helper class for simple streams that use a stream ID generator""" + + def __init__( + self, + local_instance_name: str, + update_function: UpdateFunction, + stream_id_gen: "AbstractStreamIdGenerator", + ): + self._stream_id_gen = stream_id_gen + super().__init__(local_instance_name, update_function) + + def current_token(self, instance_name: str) -> Token: + return self._stream_id_gen.get_current_token_for_writer(instance_name) + + def minimal_local_current_token(self) -> Token: + return self._stream_id_gen.get_minimal_local_current_token() + + def current_token_without_instance( current_token: Callable[[], int] ) -> Callable[[str], int]: @@ -242,17 +266,21 @@ class BackfillStream(Stream): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - self._current_token, self.store.get_all_new_backfill_event_rows, ) - def _current_token(self, instance_name: str) -> int: + def current_token(self, instance_name: str) -> Token: # The backfill stream over replication operates on *positive* numbers, # which means we need to negate it. return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name) + def minimal_local_current_token(self) -> Token: + # The backfill stream over replication operates on *positive* numbers, + # which means we need to negate it. + return -self.store._backfill_id_gen.get_minimal_local_current_token() -class PresenceStream(Stream): + +class PresenceStream(_StreamFromIdGen): @attr.s(slots=True, frozen=True, auto_attribs=True) class PresenceStreamRow: user_id: str @@ -283,9 +311,7 @@ class PresenceStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_current_presence_token), - update_function, + hs.get_instance_name(), update_function, store._presence_id_gen ) @@ -305,13 +331,18 @@ class PresenceFederationStream(Stream): ROW_TYPE = PresenceFederationStreamRow def __init__(self, hs: "HomeServer"): - federation_queue = hs.get_presence_handler().get_federation_queue() + self._federation_queue = hs.get_presence_handler().get_federation_queue() super().__init__( hs.get_instance_name(), - federation_queue.get_current_token, - federation_queue.get_replication_rows, + self._federation_queue.get_replication_rows, ) + def current_token(self, instance_name: str) -> Token: + return self._federation_queue.get_current_token(instance_name) + + def minimal_local_current_token(self) -> Token: + return self._federation_queue.get_current_token(self.local_instance_name) + class TypingStream(Stream): @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -341,20 +372,25 @@ class TypingStream(Stream): update_function: Callable[ [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]] ] = typing_writer_handler.get_all_typing_updates - current_token_function = typing_writer_handler.get_current_token + self.current_token_function = typing_writer_handler.get_current_token else: # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) - current_token_function = hs.get_typing_handler().get_current_token + self.current_token_function = hs.get_typing_handler().get_current_token super().__init__( hs.get_instance_name(), - current_token_without_instance(current_token_function), update_function, ) + def current_token(self, instance_name: str) -> Token: + return self.current_token_function() + + def minimal_local_current_token(self) -> Token: + return self.current_token_function() -class ReceiptsStream(Stream): + +class ReceiptsStream(_StreamFromIdGen): @attr.s(slots=True, frozen=True, auto_attribs=True) class ReceiptsStreamRow: room_id: str @@ -371,12 +407,12 @@ class ReceiptsStream(Stream): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_max_receipt_stream_id), store.get_all_updated_receipts, + store._receipts_id_gen, ) -class PushRulesStream(Stream): +class PushRulesStream(_StreamFromIdGen): """A user has changed their push rules""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -387,20 +423,16 @@ class PushRulesStream(Stream): ROW_TYPE = PushRulesStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastores().main + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - self._current_token, - self.store.get_all_push_rule_updates, + store.get_all_push_rule_updates, + store._push_rules_stream_id_gen, ) - def _current_token(self, instance_name: str) -> int: - push_rules_token = self.store.get_max_push_rules_stream_id() - return push_rules_token - -class PushersStream(Stream): +class PushersStream(_StreamFromIdGen): """A user has added/changed/removed a pusher""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -418,8 +450,8 @@ class PushersStream(Stream): super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_pushers_stream_token), store.get_all_updated_pushers_rows, + store._pushers_id_gen, ) @@ -447,15 +479,20 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - store.get_cache_stream_token_for_writer, - store.get_all_updated_caches, + self.store.get_all_updated_caches, ) + def current_token(self, instance_name: str) -> Token: + return self.store.get_cache_stream_token_for_writer(instance_name) + + def minimal_local_current_token(self) -> Token: + return self.current_token(self.local_instance_name) + -class DeviceListsStream(Stream): +class DeviceListsStream(_StreamFromIdGen): """Either a user has updated their devices or a remote server needs to be told about a device update. """ @@ -473,8 +510,8 @@ class DeviceListsStream(Stream): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(self.store.get_device_stream_token), self._update_function, + self.store._device_list_id_gen, ) async def _update_function( @@ -525,7 +562,7 @@ class DeviceListsStream(Stream): return updates, upper_limit_token, devices_limited or signatures_limited -class ToDeviceStream(Stream): +class ToDeviceStream(_StreamFromIdGen): """New to_device messages for a client""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -539,12 +576,12 @@ class ToDeviceStream(Stream): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_to_device_stream_token), store.get_all_new_device_messages, + store._device_inbox_id_gen, ) -class AccountDataStream(Stream): +class AccountDataStream(_StreamFromIdGen): """Global or per room account data was changed""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -560,8 +597,8 @@ class AccountDataStream(Stream): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(self.store.get_max_account_data_stream_id), self._update_function, + self.store._account_data_id_gen, ) async def _update_function( diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index da6d948e1b..38823113d8 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast import attr from synapse.replication.tcp.streams._base import ( - Stream, StreamRow, StreamUpdateResult, Token, + _StreamFromIdGen, ) if TYPE_CHECKING: @@ -139,7 +139,7 @@ _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = ( TypeToRow = {Row.TypeId: Row for Row in _EventRows} -class EventsStream(Stream): +class EventsStream(_StreamFromIdGen): """We received a new event, or an event went from being an outlier to not""" NAME = "events" @@ -147,9 +147,7 @@ class EventsStream(Stream): def __init__(self, hs: "HomeServer"): self._store = hs.get_datastores().main super().__init__( - hs.get_instance_name(), - self._store._stream_id_gen.get_current_token_for_writer, - self._update_function, + hs.get_instance_name(), self._update_function, self._store._stream_id_gen ) async def _update_function( diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 4046bdec69..7f5af5852c 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -18,6 +18,7 @@ import attr from synapse.replication.tcp.streams._base import ( Stream, + Token, current_token_without_instance, make_http_update_function, ) @@ -47,7 +48,7 @@ class FederationStream(Stream): # will be a real FederationSender, which has stubs for current_token and # get_replication_rows.) federation_sender = hs.get_federation_sender() - current_token = current_token_without_instance( + self.current_token_func = current_token_without_instance( federation_sender.get_current_token ) update_function: Callable[ @@ -57,15 +58,21 @@ class FederationStream(Stream): elif hs.should_send_federation(): # federation sender: Query master process update_function = make_http_update_function(hs, self.NAME) - current_token = self._stub_current_token + self.current_token_func = self._stub_current_token else: # other worker: stub out the update function (we're not interested in # any updates so when we get a POSITION we do nothing) update_function = self._stub_update_function - current_token = self._stub_current_token + self.current_token_func = self._stub_current_token - super().__init__(hs.get_instance_name(), current_token, update_function) + super().__init__(hs.get_instance_name(), update_function) + + def current_token(self, instance_name: str) -> Token: + return self.current_token_func(instance_name) + + def minimal_local_current_token(self) -> Token: + return self.current_token(self.local_instance_name) @staticmethod def _stub_current_token(instance_name: str) -> int: diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py index a8ce5ffd72..ad181d7e93 100644 --- a/synapse/replication/tcp/streams/partial_state.py +++ b/synapse/replication/tcp/streams/partial_state.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING import attr -from synapse.replication.tcp.streams import Stream +from synapse.replication.tcp.streams._base import _StreamFromIdGen if TYPE_CHECKING: from synapse.server import HomeServer @@ -27,7 +27,7 @@ class UnPartialStatedRoomStreamRow: room_id: str -class UnPartialStatedRoomStream(Stream): +class UnPartialStatedRoomStream(_StreamFromIdGen): """ Stream to notify about rooms becoming un-partial-stated; that is, when the background sync finishes such that we now have full state for @@ -41,8 +41,8 @@ class UnPartialStatedRoomStream(Stream): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - store.get_un_partial_stated_rooms_token, store.get_un_partial_stated_rooms_from_stream, + store._un_partial_stated_rooms_stream_id_gen, ) @@ -56,7 +56,7 @@ class UnPartialStatedEventStreamRow: rejection_status_changed: bool -class UnPartialStatedEventStream(Stream): +class UnPartialStatedEventStream(_StreamFromIdGen): """ Stream to notify about events becoming un-partial-stated. """ @@ -68,6 +68,6 @@ class UnPartialStatedEventStream(Stream): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - store.get_un_partial_stated_events_token, store.get_un_partial_stated_events_from_stream, + store._un_partial_stated_events_stream_id_gen, ) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index d2c874b9a8..9c3eafb562 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -134,6 +134,15 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod + def get_minimal_local_current_token(self) -> int: + """Tries to return a minimal current token for the local instance, + i.e. for writers this would be the last successful write. + + If local instance is not a writer (or has written yet) then falls back + to returning the normal "current token". + """ + + @abc.abstractmethod def get_next(self) -> AsyncContextManager[int]: """ Usage: @@ -312,6 +321,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator): def get_current_token_for_writer(self, instance_name: str) -> int: return self.get_current_token() + def get_minimal_local_current_token(self) -> int: + return self.get_current_token() + class MultiWriterIdGenerator(AbstractStreamIdGenerator): """Generates and tracks stream IDs for a stream with multiple writers. @@ -408,6 +420,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): # The maximum stream ID that we have seen been allocated across any writer. self._max_seen_allocated_stream_id = 1 + # The maximum position of the local instance. This can be higher than + # the corresponding position in `current_positions` table when there are + # no active writes in progress. + self._max_position_of_local_instance = self._max_seen_allocated_stream_id + self._sequence_gen = PostgresSequenceGenerator(sequence_name) # We check that the table and sequence haven't diverged. @@ -427,6 +444,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): self._current_positions.values(), default=1 ) + # For the case where `stream_positions` is not up to date, + # `_persisted_upto_position` may be higher. + self._max_seen_allocated_stream_id = max( + self._max_seen_allocated_stream_id, self._persisted_upto_position + ) + + # Bump our local maximum position now that we've loaded things from the + # DB. + self._max_position_of_local_instance = self._max_seen_allocated_stream_id + if not writers: # If there have been no explicit writers given then any instance can # write to the stream. In which case, let's pre-seed our own @@ -545,6 +572,14 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): if instance == self._instance_name: self._current_positions[instance] = stream_id + if self._writers: + # If we have explicit writers then make sure that each instance has + # a position. + for writer in self._writers: + self._current_positions.setdefault( + writer, self._persisted_upto_position + ) + cur.close() def _load_next_id_txn(self, txn: Cursor) -> int: @@ -688,6 +723,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): if new_cur: curr = self._current_positions.get(self._instance_name, 0) self._current_positions[self._instance_name] = max(curr, new_cur) + self._max_position_of_local_instance = max( + curr, new_cur, self._max_position_of_local_instance + ) self._add_persisted_position(next_id) @@ -702,10 +740,26 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): # persisted up to position. This stops Synapse from doing a full table # scan when a new writer announces itself over replication. with self._lock: - return self._return_factor * self._current_positions.get( + if self._instance_name == instance_name: + return self._return_factor * self._max_position_of_local_instance + + pos = self._current_positions.get( instance_name, self._persisted_upto_position ) + # We want to return the maximum "current token" that we can for a + # writer, this helps ensure that streams progress as fast as + # possible. + pos = max(pos, self._persisted_upto_position) + + return self._return_factor * pos + + def get_minimal_local_current_token(self) -> int: + with self._lock: + return self._return_factor * self._current_positions.get( + self._instance_name, self._persisted_upto_position + ) + def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. @@ -774,6 +828,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): self._persisted_upto_position = max(min_curr, self._persisted_upto_position) + # Advance our local max position. + self._max_position_of_local_instance = max( + self._max_position_of_local_instance, self._persisted_upto_position + ) + + if not self._unfinished_ids and not self._in_flight_fetches: + # If we don't have anything in flight, it's safe to advance to the + # max seen stream ID. + self._max_position_of_local_instance = max( + self._max_seen_allocated_stream_id, self._max_position_of_local_instance + ) + # We now iterate through the seen positions, discarding those that are # less than the current min positions, and incrementing the min position # if its exactly one greater. |