diff options
Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 129 |
1 files changed, 83 insertions, 46 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index c6088a0f99..5c4d228f3d 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -33,6 +33,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.util.id_generators import AbstractStreamIdGenerator logger = logging.getLogger(__name__) @@ -107,22 +108,10 @@ class Stream: def __init__( self, local_instance_name: str, - current_token_function: Callable[[str], Token], update_function: UpdateFunction, ): """Instantiate a Stream - `current_token_function` and `update_function` are callbacks which - should be implemented by subclasses. - - `current_token_function` takes an instance name, which is a writer to - the stream, and returns the position in the stream of the writer (as - viewed from the current process). On the writer process this is where - the writer has successfully written up to, whereas on other processes - this is the position which we have received updates up to over - replication. (Note that most streams have a single writer and so their - implementations ignore the instance name passed in). - `update_function` is called to get updates for this stream between a pair of stream tokens. See the `UpdateFunction` type definition for more info. @@ -133,12 +122,28 @@ class Stream: update_function: callback go get stream updates, as above """ self.local_instance_name = local_instance_name - self.current_token = current_token_function self.update_function = update_function # The token from which we last asked for updates self.last_token = self.current_token(self.local_instance_name) + def current_token(self, instance_name: str) -> Token: + """This takes an instance name, which is a writer to + the stream, and returns the position in the stream of the writer (as + viewed from the current process). + """ + # We can't make this an abstract class as it makes mypy unhappy. + raise NotImplementedError() + + def minimal_local_current_token(self) -> Token: + """Tries to return a minimal current token for the local instance, + i.e. for writers this would be the last successful write. + + If local instance is not a writer (or has written yet) then falls back + to returning the normal "current token". + """ + raise NotImplementedError() + def discard_updates_and_advance(self) -> None: """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. @@ -190,6 +195,25 @@ class Stream: return updates, upto_token, limited +class _StreamFromIdGen(Stream): + """Helper class for simple streams that use a stream ID generator""" + + def __init__( + self, + local_instance_name: str, + update_function: UpdateFunction, + stream_id_gen: "AbstractStreamIdGenerator", + ): + self._stream_id_gen = stream_id_gen + super().__init__(local_instance_name, update_function) + + def current_token(self, instance_name: str) -> Token: + return self._stream_id_gen.get_current_token_for_writer(instance_name) + + def minimal_local_current_token(self) -> Token: + return self._stream_id_gen.get_minimal_local_current_token() + + def current_token_without_instance( current_token: Callable[[], int] ) -> Callable[[str], int]: @@ -242,17 +266,21 @@ class BackfillStream(Stream): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - self._current_token, self.store.get_all_new_backfill_event_rows, ) - def _current_token(self, instance_name: str) -> int: + def current_token(self, instance_name: str) -> Token: # The backfill stream over replication operates on *positive* numbers, # which means we need to negate it. return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name) + def minimal_local_current_token(self) -> Token: + # The backfill stream over replication operates on *positive* numbers, + # which means we need to negate it. + return -self.store._backfill_id_gen.get_minimal_local_current_token() -class PresenceStream(Stream): + +class PresenceStream(_StreamFromIdGen): @attr.s(slots=True, frozen=True, auto_attribs=True) class PresenceStreamRow: user_id: str @@ -283,9 +311,7 @@ class PresenceStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_current_presence_token), - update_function, + hs.get_instance_name(), update_function, store._presence_id_gen ) @@ -305,13 +331,18 @@ class PresenceFederationStream(Stream): ROW_TYPE = PresenceFederationStreamRow def __init__(self, hs: "HomeServer"): - federation_queue = hs.get_presence_handler().get_federation_queue() + self._federation_queue = hs.get_presence_handler().get_federation_queue() super().__init__( hs.get_instance_name(), - federation_queue.get_current_token, - federation_queue.get_replication_rows, + self._federation_queue.get_replication_rows, ) + def current_token(self, instance_name: str) -> Token: + return self._federation_queue.get_current_token(instance_name) + + def minimal_local_current_token(self) -> Token: + return self._federation_queue.get_current_token(self.local_instance_name) + class TypingStream(Stream): @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -341,20 +372,25 @@ class TypingStream(Stream): update_function: Callable[ [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]] ] = typing_writer_handler.get_all_typing_updates - current_token_function = typing_writer_handler.get_current_token + self.current_token_function = typing_writer_handler.get_current_token else: # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) - current_token_function = hs.get_typing_handler().get_current_token + self.current_token_function = hs.get_typing_handler().get_current_token super().__init__( hs.get_instance_name(), - current_token_without_instance(current_token_function), update_function, ) + def current_token(self, instance_name: str) -> Token: + return self.current_token_function() + + def minimal_local_current_token(self) -> Token: + return self.current_token_function() -class ReceiptsStream(Stream): + +class ReceiptsStream(_StreamFromIdGen): @attr.s(slots=True, frozen=True, auto_attribs=True) class ReceiptsStreamRow: room_id: str @@ -371,12 +407,12 @@ class ReceiptsStream(Stream): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_max_receipt_stream_id), store.get_all_updated_receipts, + store._receipts_id_gen, ) -class PushRulesStream(Stream): +class PushRulesStream(_StreamFromIdGen): """A user has changed their push rules""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -387,20 +423,16 @@ class PushRulesStream(Stream): ROW_TYPE = PushRulesStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastores().main + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - self._current_token, - self.store.get_all_push_rule_updates, + store.get_all_push_rule_updates, + store._push_rules_stream_id_gen, ) - def _current_token(self, instance_name: str) -> int: - push_rules_token = self.store.get_max_push_rules_stream_id() - return push_rules_token - -class PushersStream(Stream): +class PushersStream(_StreamFromIdGen): """A user has added/changed/removed a pusher""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -418,8 +450,8 @@ class PushersStream(Stream): super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_pushers_stream_token), store.get_all_updated_pushers_rows, + store._pushers_id_gen, ) @@ -447,15 +479,20 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - store.get_cache_stream_token_for_writer, - store.get_all_updated_caches, + self.store.get_all_updated_caches, ) + def current_token(self, instance_name: str) -> Token: + return self.store.get_cache_stream_token_for_writer(instance_name) + + def minimal_local_current_token(self) -> Token: + return self.current_token(self.local_instance_name) + -class DeviceListsStream(Stream): +class DeviceListsStream(_StreamFromIdGen): """Either a user has updated their devices or a remote server needs to be told about a device update. """ @@ -473,8 +510,8 @@ class DeviceListsStream(Stream): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(self.store.get_device_stream_token), self._update_function, + self.store._device_list_id_gen, ) async def _update_function( @@ -525,7 +562,7 @@ class DeviceListsStream(Stream): return updates, upper_limit_token, devices_limited or signatures_limited -class ToDeviceStream(Stream): +class ToDeviceStream(_StreamFromIdGen): """New to_device messages for a client""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -539,12 +576,12 @@ class ToDeviceStream(Stream): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_to_device_stream_token), store.get_all_new_device_messages, + store._device_inbox_id_gen, ) -class AccountDataStream(Stream): +class AccountDataStream(_StreamFromIdGen): """Global or per room account data was changed""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -560,8 +597,8 @@ class AccountDataStream(Stream): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(self.store.get_max_account_data_stream_id), self._update_function, + self.store._account_data_id_gen, ) async def _update_function( |