diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index ced69ee904..ce5d651cb8 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,22 +14,36 @@
"""A replication client for use by synapse workers.
"""
import logging
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.api.constants import EventTypes
+from synapse.federation import send_queue
+from synapse.federation.sender import FederationSender
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.streams import TypingStream
+from synapse.replication.tcp.streams import (
+ AccountDataStream,
+ DeviceListsStream,
+ GroupServerStream,
+ PresenceStream,
+ PushersStream,
+ PushRulesStream,
+ ReceiptsStream,
+ TagAccountDataStream,
+ ToDeviceStream,
+ TypingStream,
+)
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
-from synapse.types import PersistedEventPosition, UserID
-from synapse.util.async_helpers import timeout_deferred
+from synapse.types import PersistedEventPosition, ReadReceipt, UserID
+from synapse.util.async_helpers import Linearizer, timeout_deferred
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -105,6 +119,14 @@ class ReplicationDataHandler:
self._instance_name = hs.get_instance_name()
self._typing_handler = hs.get_typing_handler()
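+        # Whether this worker is responsible for running pushers (used by the
+        # start_pusher/stop_pusher helpers below).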
+ self._notify_pushers = hs.config.start_pushers
+ self._pusher_pool = hs.get_pusherpool()
+ self._presence_handler = hs.get_presence_handler()
+
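+        # Only the worker that sends outbound federation gets a
+        # FederationSenderHandler; it forwards the relevant replication rows to
+        # the federation sender.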
+ self.send_handler = None # type: Optional[FederationSenderHandler]
+ if hs.should_send_federation():
+ self.send_handler = FederationSenderHandler(hs)
+
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]]
@@ -125,13 +147,53 @@ class ReplicationDataHandler:
"""
self.store.process_replication_rows(stream_name, instance_name, token, rows)
+ if self.send_handler:
+ await self.send_handler.process_replication_rows(stream_name, token, rows)
+
if stream_name == TypingStream.NAME:
self._typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows]
)
-
- if stream_name == EventsStream.NAME:
+ elif stream_name == PushRulesStream.NAME:
+ self.notifier.on_new_event(
+ "push_rules_key", token, users=[row.user_id for row in rows]
+ )
+ elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
+ self.notifier.on_new_event(
+ "account_data_key", token, users=[row.user_id for row in rows]
+ )
+ elif stream_name == ReceiptsStream.NAME:
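+            # Receipts go both to the notifier (to wake up sync streams) and to
+            # the pusher pool.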
+ self.notifier.on_new_event(
+ "receipt_key", token, rooms=[row.room_id for row in rows]
+ )
+ await self._pusher_pool.on_new_receipts(
+ token, token, {row.room_id for row in rows}
+ )
+ elif stream_name == ToDeviceStream.NAME:
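+            # Rows can target local users ('@...') or remote servers. Only
+            # local users need a notifier poke here; remote delivery is handled
+            # by the federation sender below.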
+ entities = [row.entity for row in rows if row.entity.startswith("@")]
+ if entities:
+ self.notifier.on_new_event("to_device_key", token, users=entities)
+ elif stream_name == DeviceListsStream.NAME:
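+            # The entity is either a user ID or a remote server name. For user
+            # entries we poke every room the user is in so that sync streams
+            # pick up the device list change.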
+ all_room_ids = set() # type: Set[str]
+ for row in rows:
+ if row.entity.startswith("@"):
+ room_ids = await self.store.get_rooms_for_user(row.entity)
+ all_room_ids.update(room_ids)
+ self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
+ elif stream_name == GroupServerStream.NAME:
+ self.notifier.on_new_event(
+ "groups_key", token, users=[row.user_id for row in rows]
+ )
+ elif stream_name == PushersStream.NAME:
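+            # A pusher row either reports a deletion or an addition/update, so
+            # stop or (re)start the pusher on this worker accordingly.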
+ for row in rows:
+ if row.deleted:
+ self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+ else:
+ await self.start_pusher(row.user_id, row.app_id, row.pushkey)
+ elif stream_name == PresenceStream.NAME:
+ await self._presence_handler.process_replication_rows(token, rows)
+ elif stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
for row in rows:
@@ -190,7 +252,7 @@ class ReplicationDataHandler:
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
async def on_position(self, stream_name: str, instance_name: str, token: int):
- self.store.process_replication_rows(stream_name, instance_name, token, [])
+ await self.on_rdata(stream_name, instance_name, token, [])
# We poke the generic "replication" notifier to wake anything up that
# may be streaming.
@@ -199,6 +261,11 @@ class ReplicationDataHandler:
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
+ # Let's wake up the transaction queue for the server in case we have
+ # pending stuff to send to it.
+ if self.send_handler:
+ self.send_handler.wake_destination(server)
+
async def wait_for_stream_position(
self, instance_name: str, stream_name: str, position: int
):
@@ -235,3 +302,153 @@ class ReplicationDataHandler:
logger.info(
"Finished waiting for repl stream %r to reach %s", stream_name, position
)
+
+ def stop_pusher(self, user_id, app_id, pushkey):
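+        """Stop the given user's pusher identified by app_id/pushkey, if it is
+        running on this worker."""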
+ if not self._notify_pushers:
+ return
+
+ key = "%s:%s" % (app_id, pushkey)
+ pushers_for_user = self._pusher_pool.pushers.get(user_id, {})
+ pusher = pushers_for_user.pop(key, None)
+ if pusher is None:
+ return
+ logger.info("Stopping pusher %r / %r", user_id, key)
+ pusher.on_stop()
+
+ async def start_pusher(self, user_id, app_id, pushkey):
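+        """(Re)start the given user's pusher identified by app_id/pushkey, if
+        this worker handles pushers."""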
+ if not self._notify_pushers:
+ return
+
+ key = "%s:%s" % (app_id, pushkey)
+ logger.info("Starting pusher %r / %r", user_id, key)
+ return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+
+
+class FederationSenderHandler:
+    """Processes the federation replication stream
+
+    This class is only instantiated on the worker responsible for sending outbound
+ federation transactions. It receives rows from the replication stream and forwards
+ the appropriate entries to the FederationSender class.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ assert hs.should_send_federation()
+
+ self.store = hs.get_datastore()
+ self._is_mine_id = hs.is_mine_id
+ self._hs = hs
+
+        # We assign to a temporary variable so that mypy picks up the right
+        # type. We know we should have a federation sender instance since
+        # `should_send_federation` is True.
+ sender = hs.get_federation_sender()
+ assert isinstance(sender, FederationSender)
+ self.federation_sender = sender
+
+ # Stores the latest position in the federation stream we've gotten up
+ # to. This is always set before we use it.
+ self.federation_position = None # type: Optional[int]
+
+ self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
+
+ def wake_destination(self, server: str):
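+        """Poke the federation sender for a server that has just come back up,
+        so that anything queued for it gets sent."""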
+ self.federation_sender.wake_destination(server)
+
+ async def process_replication_rows(self, stream_name, token, rows):
+ # The federation stream contains things that we want to send out, e.g.
+ # presence, typing, etc.
+ if stream_name == "federation":
+ send_queue.process_rows_for_federation(self.federation_sender, rows)
+ await self.update_token(token)
+
+ # ... and when new receipts happen
+ elif stream_name == ReceiptsStream.NAME:
+ await self._on_new_receipts(rows)
+
+ # ... as well as device updates and messages
+ elif stream_name == DeviceListsStream.NAME:
+ # The entities are either user IDs (starting with '@') whose devices
+ # have changed, or remote servers that we need to tell about
+ # changes.
+ hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+ for host in hosts:
+ self.federation_sender.send_device_messages(host)
+
+ elif stream_name == ToDeviceStream.NAME:
+ # The to_device stream includes stuff to be pushed to both local
+ # clients and remote servers, so we ignore entities that start with
+ # '@' (since they'll be local users rather than destinations).
+ hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+ for host in hosts:
+ self.federation_sender.send_device_messages(host)
+
+ async def _on_new_receipts(self, rows):
+ """
+ Args:
+ rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
+ new receipts to be processed
+ """
+ for receipt in rows:
+ # we only want to send on receipts for our own users
+ if not self._is_mine_id(receipt.user_id):
+ continue
+ receipt_info = ReadReceipt(
+ receipt.room_id,
+ receipt.receipt_type,
+ receipt.user_id,
+ [receipt.event_id],
+ receipt.data,
+ )
+ await self.federation_sender.send_read_receipt(receipt_info)
+
+ async def update_token(self, token):
+ """Update the record of where we have processed to in the federation stream.
+
+        Called after we have processed an update received over replication. Sends
+ a FEDERATION_ACK back to the master, and stores the token that we have processed
+ in `federation_stream_position` so that we can restart where we left off.
+ """
+ self.federation_position = token
+
+ # We save and send the ACK to master asynchronously, so we don't block
+ # processing on persistence. We don't need to do this operation for
+        # every single RDATA we receive; doing it periodically is enough.
+
+ if self._fed_position_linearizer.is_queued(None):
+ # There is already a task queued up to save and send the token, so
+ # no need to queue up another task.
+ return
+
+ run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
+
+ async def _save_and_send_ack(self):
+ """Save the current federation position in the database and send an ACK
+ to master with where we're up to.
+ """
+ # We should only be calling this once we've got a token.
+ assert self.federation_position is not None
+
+ try:
+ # We linearize here to ensure we don't have races updating the token
+ #
+ # XXX this appears to be redundant, since the ReplicationCommandHandler
+ # has a linearizer which ensures that we only process one line of
+ # replication data at a time. Should we remove it, or is it doing useful
+ # service for robustness? Or could we replace it with an assertion that
+ # we're not being re-entered?
+
+ with (await self._fed_position_linearizer.queue(None)):
+ # We persist and ack the same position, so we take a copy of it
+ # here as otherwise it can get modified from underneath us.
+ current_position = self.federation_position
+
+ await self.store.update_federation_out_pos(
+ "federation", current_position
+ )
+
+ # We ACK this token over replication so that the master can drop
+                # its in-memory queues.
+ self._hs.get_tcp_replication().send_federation_ack(current_position)
+ except Exception:
+ logger.exception("Error updating federation stream position")
|