path: root/synapse/replication/tcp/handler.py
Diffstat (limited to 'synapse/replication/tcp/handler.py')
-rw-r--r--  synapse/replication/tcp/handler.py | 71
1 file changed, 29 insertions, 42 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py

index b620b29dfb..c197f6c26d 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -329,10 +329,17 @@ class ReplicationClientHandler:
         self.send_command(RdataCommand(stream_name, token, data))


-class DummyReplicationDataHandler:
+class ReplicationDataHandler:
     """A replication data handler that simply discards all data.
     """

+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+        self.typing_handler = hs.get_typing_handler()
+
+        self.slaved_store = hs.config.worker_app is not None
+        self.slaved_typing = not hs.config.server.handle_typing
+
     async def on_rdata(self, stream_name: str, token: int, rows: list):
         """Called to handle a batch of replication data with a given stream token.

@@ -345,7 +352,11 @@ class DummyReplicationDataHandler:
             rows (list): a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
         """
-        pass
+        if self.slaved_store:
+            self.store.process_replication_rows(stream_name, token, rows)
+
+        if self.slaved_typing:
+            self.typing_handler.process_replication_rows(stream_name, token, rows)

     def get_streams_to_replicate(self) -> Dict[str, int]:
         """Called when a new connection has been established and we need to
@@ -355,49 +366,25 @@ class DummyReplicationDataHandler:
             map from stream name to the most recent update we have for
             that stream (ie, the point we want to start replicating from)
         """
-        return {}
-
-    async def on_position(self, stream_name: str, token: int):
-        pass
-
+        args = {}  # type: Dict[str, int]

-class WorkerReplicationDataHandler:
-    """A replication data handler that calls slave data stores.
-    """
-
-    def __init__(self, store):
-        self.store = store
-
-    async def on_rdata(self, stream_name: str, token: int, rows: list):
-        """Called to handle a batch of replication data with a given stream token.
+        if self.slaved_store:
+            args = self.store.stream_positions()
+            user_account_data = args.pop("user_account_data", None)
+            room_account_data = args.pop("room_account_data", None)
+            if user_account_data:
+                args["account_data"] = user_account_data
+            elif room_account_data:
+                args["account_data"] = room_account_data

-        By default this just pokes the slave store. Can be overridden in subclasses to
-        handle more.
+        if self.slaved_typing:
+            args.update(self.typing_handler.stream_positions())

-        Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
-        """
-        self.store.process_replication_rows(stream_name, token, rows)
-
-    def get_streams_to_replicate(self) -> Dict[str, int]:
-        """Called when a new connection has been established and we need to
-        subscribe to streams.
-
-        Returns:
-            map from stream name to the most recent update we have for
-            that stream (ie, the point we want to start replicating from)
-        """
-        args = self.store.stream_positions()
-        user_account_data = args.pop("user_account_data", None)
-        room_account_data = args.pop("room_account_data", None)
-        if user_account_data:
-            args["account_data"] = user_account_data
-        elif room_account_data:
-            args["account_data"] = room_account_data
         return args

     async def on_position(self, stream_name: str, token: int):
-        self.store.process_replication_rows(stream_name, token, [])
+        if self.slaved_store:
+            self.store.process_replication_rows(stream_name, token, [])
+
+        if self.slaved_typing:
+            self.typing_handler.process_replication_rows(stream_name, token, [])
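
Note: the change above folds DummyReplicationDataHandler and WorkerReplicationDataHandler into a single ReplicationDataHandler whose callbacks only act on components that are actually slaved on the current process. The following is a minimal, self-contained sketch of that dispatch pattern, not Synapse code: StubStore, StubTypingHandler, the constructor flags and the demo values are illustrative assumptions standing in for hs.get_datastore() / hs.get_typing_handler().

# Minimal sketch of the "single handler, per-component flags" pattern above.
# All Stub* names and values are illustrative, not the real Synapse APIs.
import asyncio
from typing import Dict, List


class StubStore:
    def process_replication_rows(self, stream_name: str, token: int, rows: List) -> None:
        print("store  <-", stream_name, token, len(rows), "rows")

    def stream_positions(self) -> Dict[str, int]:
        # Pretend the slaved store last saw these stream tokens.
        return {"events": 100, "user_account_data": 5}


class StubTypingHandler:
    def process_replication_rows(self, stream_name: str, token: int, rows: List) -> None:
        print("typing <-", stream_name, token, len(rows), "rows")

    def stream_positions(self) -> Dict[str, int]:
        return {"typing": 7}


class ReplicationDataHandler:
    """One handler for every worker: each callback only pokes the components
    that are slaved on this process, instead of having one class per case."""

    def __init__(self, slaved_store: bool, slaved_typing: bool) -> None:
        self.store = StubStore()
        self.typing_handler = StubTypingHandler()
        self.slaved_store = slaved_store
        self.slaved_typing = slaved_typing

    async def on_rdata(self, stream_name: str, token: int, rows: list) -> None:
        # Mirrors the diff: forward the batch to whichever slaved parts exist.
        if self.slaved_store:
            self.store.process_replication_rows(stream_name, token, rows)
        if self.slaved_typing:
            self.typing_handler.process_replication_rows(stream_name, token, rows)

    def get_streams_to_replicate(self) -> Dict[str, int]:
        args: Dict[str, int] = {}
        if self.slaved_store:
            args = self.store.stream_positions()
            # Collapse the two account-data positions into a single
            # "account_data" token, as the diff does.
            user_account_data = args.pop("user_account_data", None)
            room_account_data = args.pop("room_account_data", None)
            if user_account_data:
                args["account_data"] = user_account_data
            elif room_account_data:
                args["account_data"] = room_account_data
        if self.slaved_typing:
            args.update(self.typing_handler.stream_positions())
        return args


handler = ReplicationDataHandler(slaved_store=True, slaved_typing=True)
print(handler.get_streams_to_replicate())
asyncio.run(handler.on_rdata("typing", 8, [("!room:example.com", ["@alice:example.com"])]))

With the stub values this prints {'events': 100, 'account_data': 5, 'typing': 7} and then routes the typing batch to both stubs; in Synapse itself each component's process_replication_rows is expected to filter on stream_name, so forwarding every batch to every slaved component is harmless.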