diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b620b29dfb..c197f6c26d 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -329,10 +329,17 @@ class ReplicationClientHandler:
self.send_command(RdataCommand(stream_name, token, data))
-class DummyReplicationDataHandler:
+class ReplicationDataHandler:
"""A replication data handler that simply discards all data.
"""
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.typing_handler = hs.get_typing_handler()
+
+ self.slaved_store = hs.config.worker_app is not None
+ self.slaved_typing = not hs.config.server.handle_typing
+
async def on_rdata(self, stream_name: str, token: int, rows: list):
"""Called to handle a batch of replication data with a given stream token.
@@ -345,7 +352,11 @@ class DummyReplicationDataHandler:
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
- pass
+ if self.slaved_store:
+ self.store.process_replication_rows(stream_name, token, rows)
+
+ if self.slaved_typing:
+ self.typing_handler.process_replication_rows(stream_name, token, rows)
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
@@ -355,49 +366,25 @@ class DummyReplicationDataHandler:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
- return {}
-
- async def on_position(self, stream_name: str, token: int):
- pass
-
+ args = {} # type: Dict[str, int]
-class WorkerReplicationDataHandler:
- """A replication data handler that calls slave data stores.
- """
-
- def __init__(self, store):
- self.store = store
-
- async def on_rdata(self, stream_name: str, token: int, rows: list):
- """Called to handle a batch of replication data with a given stream token.
+ if self.slaved_store:
+ args = self.store.stream_positions()
+ user_account_data = args.pop("user_account_data", None)
+ room_account_data = args.pop("room_account_data", None)
+ if user_account_data:
+ args["account_data"] = user_account_data
+ elif room_account_data:
+ args["account_data"] = room_account_data
- By default this just pokes the slave store. Can be overridden in subclasses to
- handle more.
+ if self.slaved_typing:
+ args.update(self.typing_handler.stream_positions())
- Args:
- stream_name (str): name of the replication stream for this batch of rows
- token (int): stream token for this batch of rows
- rows (list): a list of Stream.ROW_TYPE objects as returned by
- Stream.parse_row.
- """
- self.store.process_replication_rows(stream_name, token, rows)
-
- def get_streams_to_replicate(self) -> Dict[str, int]:
- """Called when a new connection has been established and we need to
- subscribe to streams.
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
- """
- args = self.store.stream_positions()
- user_account_data = args.pop("user_account_data", None)
- room_account_data = args.pop("room_account_data", None)
- if user_account_data:
- args["account_data"] = user_account_data
- elif room_account_data:
- args["account_data"] = room_account_data
return args
async def on_position(self, stream_name: str, token: int):
- self.store.process_replication_rows(stream_name, token, [])
+ if self.slaved_store:
+ self.store.process_replication_rows(stream_name, token, [])
+
+ if self.slaved_typing:
+ self.typing_handler.process_replication_rows(stream_name, token, [])
|