diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2f59245058..e4f2201c92 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector
from twisted.internet.protocol import ReconnectingClientFactory
from twisted.python.failure import Failure
-from synapse.api.constants import EventTypes, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.federation import send_queue
from synapse.federation.sender import FederationSender
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
@@ -219,6 +219,21 @@ class ReplicationDataHandler:
membership=row.data.membership,
)
+ # If this event is a join, make a note of it so we have an accurate
+ # cross-worker room rate limit.
+ # TODO: Erik said we should exclude rows that came from ex_outliers
+ # here, but I don't see how we can determine that. I guess we could
+ # add a flag to row.data?
+ if (
+ row.data.type == EventTypes.Member
+ and row.data.membership == Membership.JOIN
+ and not row.data.outlier
+ ):
+ # TODO retrieve the previous state, and exclude join -> join transitions
+ self.notifier.notify_user_joined_room(
+ row.data.event_id, row.data.room_id
+ )
+
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows
)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index e1cbfa50eb..0f166d16aa 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -35,7 +35,6 @@ from twisted.internet.protocol import ReconnectingClientFactory
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
@@ -332,46 +331,31 @@ class ReplicationCommandHandler:
def start_replication(self, hs: "HomeServer") -> None:
"""Helper method to start replication."""
- if hs.config.redis.redis_enabled:
- from synapse.replication.tcp.redis import (
- RedisDirectTcpReplicationClientFactory,
- )
+ from synapse.replication.tcp.redis import RedisDirectTcpReplicationClientFactory
- # First let's ensure that we have a ReplicationStreamer started.
- hs.get_replication_streamer()
+ # First let's ensure that we have a ReplicationStreamer started.
+ hs.get_replication_streamer()
- # We need two connections to redis, one for the subscription stream and
- # one to send commands to (as you can't send further redis commands to a
- # connection after SUBSCRIBE is called).
+ # We need two connections to redis, one for the subscription stream and
+ # one to send commands to (as you can't send further redis commands to a
+ # connection after SUBSCRIBE is called).
- # First create the connection for sending commands.
- outbound_redis_connection = hs.get_outbound_redis_connection()
+ # First create the connection for sending commands.
+ outbound_redis_connection = hs.get_outbound_redis_connection()
- # Now create the factory/connection for the subscription stream.
- self._factory = RedisDirectTcpReplicationClientFactory(
- hs,
- outbound_redis_connection,
- channel_names=self._channels_to_subscribe_to,
- )
- hs.get_reactor().connectTCP(
- hs.config.redis.redis_host,
- hs.config.redis.redis_port,
- self._factory,
- timeout=30,
- bindAddress=None,
- )
- else:
- client_name = hs.get_instance_name()
- self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
- host = hs.config.worker.worker_replication_host
- port = hs.config.worker.worker_replication_port
- hs.get_reactor().connectTCP(
- host,
- port,
- self._factory,
- timeout=30,
- bindAddress=None,
- )
+ # Now create the factory/connection for the subscription stream.
+ self._factory = RedisDirectTcpReplicationClientFactory(
+ hs,
+ outbound_redis_connection,
+ channel_names=self._channels_to_subscribe_to,
+ )
+ hs.get_reactor().connectTCP(
+ hs.config.redis.redis_host,
+ hs.config.redis.redis_port,
+ self._factory,
+ timeout=30,
+ bindAddress=None,
+ )
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 26f4fa7cfd..14b6705862 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -98,6 +98,7 @@ class EventsStreamEventRow(BaseEventsStreamRow):
relates_to: Optional[str]
membership: Optional[str]
rejected: bool
+ outlier: bool
@attr.s(slots=True, frozen=True, auto_attribs=True)
|