diff options
Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 50 |
1 files changed, 37 insertions, 13 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4af1afd119..b0f87c365b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr @@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # # The arguments are: # +# * instance_name: the writer of the stream # * from_token: the previous stream token: the starting point for fetching the # updates # * to_token: the new stream token: the point to get updates up to @@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # If there are more updates available, it should set `limited` in the result, and # it will be called again to get the next batch. # -UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]] +UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] class Stream(object): @@ -93,6 +94,7 @@ class Stream(object): def __init__( self, + local_instance_name: str, current_token_function: Callable[[], Token], update_function: UpdateFunction, ): @@ -108,9 +110,11 @@ class Stream(object): stream tokens. See the UpdateFunction type definition for more info. Args: + local_instance_name: The instance name of the current process current_token_function: callback to get the current token, as above update_function: callback go get stream updates, as above """ + self.local_instance_name = local_instance_name self.current_token = current_token_function self.update_function = update_function @@ -135,14 +139,14 @@ class Stream(object): """ current_token = self.current_token() updates, current_token, limited = await self.get_updates_since( - self.last_token, current_token + self.local_instance_name, self.last_token, current_token ) self.last_token = current_token return updates, current_token, limited async def get_updates_since( - self, from_token: Token, upto_token: Token + self, instance_name: str, from_token: Token, upto_token: Token ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates @@ -160,19 +164,19 @@ class Stream(object): return [], upto_token, False updates, upto_token, limited = await self.update_function( - from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, ) return updates, upto_token, limited def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]] + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: """Wraps a db query function which returns a list of rows to make it suitable for use as an `update_function` for the Stream class """ - async def update_function(from_token, upto_token, limit): + async def update_function(instance_name, from_token, upto_token, limit): rows = await query_function(from_token, upto_token, limit) updates = [(row[0], row[1:]) for row in rows] limited = False @@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction: client = ReplicationGetStreamUpdates.make_client(hs) async def update_function( - from_token: int, upto_token: int, limit: int + instance_name: str, from_token: int, upto_token: int, limit: int ) -> StreamUpdateResult: result = await client( - stream_name=stream_name, from_token=from_token, upto_token=upto_token, + instance_name=instance_name, + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, ) return result["updates"], result["upto_token"], result["limited"] @@ -226,6 +233,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_backfill_token, db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -261,7 +269,9 @@ class PresenceStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(store.get_current_presence_token, update_function) + super().__init__( + hs.get_instance_name(), store.get_current_presence_token, update_function + ) class TypingStream(Stream): @@ -284,7 +294,9 @@ class TypingStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(typing_handler.get_current_token, update_function) + super().__init__( + hs.get_instance_name(), typing_handler.get_current_token, update_function + ) class ReceiptsStream(Stream): @@ -305,6 +317,7 @@ class ReceiptsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_receipt_stream_id, db_query_to_update_function(store.get_all_updated_receipts), ) @@ -322,14 +335,16 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super(PushRulesStream, self).__init__( - self._current_token, self._update_function + hs.get_instance_name(), self._current_token, self._update_function ) def _current_token(self) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - async def _update_function(self, from_token: Token, to_token: Token, limit: int): + async def _update_function( + self, instance_name: str, from_token: Token, to_token: Token, limit: int + ): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) limited = False @@ -356,6 +371,7 @@ class PushersStream(Stream): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_pushers_stream_token, db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -387,6 +403,7 @@ class CachesStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_cache_stream_token, db_query_to_update_function(store.get_all_updated_caches), ) @@ -412,6 +429,7 @@ class PublicRoomsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_public_room_stream_id, db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -432,6 +450,7 @@ class DeviceListsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -449,6 +468,7 @@ class ToDeviceStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_to_device_stream_token, db_query_to_update_function(store.get_all_new_device_messages), ) @@ -468,6 +488,7 @@ class TagAccountDataStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_account_data_stream_id, db_query_to_update_function(store.get_all_updated_tags), ) @@ -487,6 +508,7 @@ class AccountDataStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super().__init__( + hs.get_instance_name(), self.store.get_max_account_data_stream_id, db_query_to_update_function(self._update_function), ) @@ -517,6 +539,7 @@ class GroupServerStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_group_stream_token, db_query_to_update_function(store.get_all_groups_changes), ) @@ -534,6 +557,7 @@ class UserSignatureStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function( store.get_all_user_signature_changes_for_remotes |