diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 4328b38e9d..acfa66a7a8 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,18 +15,7 @@
# limitations under the License.
import logging
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
from prometheus_client import Counter
@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
    ClearUserSyncsCommand,
    Command,
    FederationAckCommand,
-    InvalidateCacheCommand,
    PositionCommand,
    RdataCommand,
    RemoteServerUpCommand,
@@ -48,7 +36,12 @@ from synapse.replication.tcp.commands import (
    UserSyncCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
-from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.replication.tcp.streams import (
+    STREAMS_MAP,
+    CachesStream,
+    FederationStream,
+    Stream,
+)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -85,6 +78,26 @@ class ReplicationCommandHandler:
            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
        }  # type: Dict[str, Stream]
+        # List of streams that this instance is the source of
+        self._streams_to_replicate = []  # type: List[Stream]
+
+        for stream in self._streams.values():
+            if stream.NAME == CachesStream.NAME:
+                # All workers can write to the cache invalidation stream.
+                self._streams_to_replicate.append(stream)
+                continue
+
+            # Only add any other streams if we're on master.
+            if hs.config.worker_app is not None:
+                continue
+
+            if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+                # We only support the federation stream if federation sending
+                # has been disabled on the master.
+                continue
+
+            self._streams_to_replicate.append(stream)
+
        self._position_linearizer = Linearizer(
            "replication_position", clock=self._clock
        )
@@ -162,16 +175,33 @@ class ReplicationCommandHandler:
            port = hs.config.worker_replication_port
            hs.get_reactor().connectTCP(host, port, self._factory)
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a map from stream name to all streams.
+        """
+        return self._streams
+
+    def get_streams_to_replicate(self) -> List[Stream]:
+ """Get a list of streams that this instances replicates.
+ """
+ return self._streams_to_replicate
+
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
- # We only want to announce positions by the writer of the streams.
- # Currently this is just the master process.
- if not self._is_master:
- return
+ self.send_positions_to_connection(conn)
- for stream_name, stream in self._streams.items():
- current_token = stream.current_token()
+    def send_positions_to_connection(self, conn: AbstractConnection):
+        """Send the current position of all streams this process is the source
+        of to the given connection.
+        """
+
+        # We respond with the current position of all streams this instance
+        # replicates.
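+        # Each position is sent under our own instance name, with our current
+        # token for that stream, since some streams (e.g. the caches stream)
+        # are written to by every instance.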
+        for stream in self.get_streams_to_replicate():
            self.send_command(
-                PositionCommand(stream_name, self._instance_name, current_token)
+                PositionCommand(
+                    stream.NAME,
+                    self._instance_name,
+                    stream.current_token(self._instance_name),
+                )
            )
    async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
@@ -208,18 +238,6 @@ class ReplicationCommandHandler:
        self._notifier.on_new_replication_data()
-    async def on_INVALIDATE_CACHE(
-        self, conn: AbstractConnection, cmd: InvalidateCacheCommand
-    ):
-        invalidate_cache_counter.inc()
-
-        if self._is_master:
-            # We invalidate the cache locally, but then also stream that to other
-            # workers.
-            await self._store.invalidate_cache_and_stream(
-                cmd.cache_func, tuple(cmd.keys)
-            )
-
    async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
        user_ip_cache_counter.inc()
@@ -293,7 +311,7 @@ class ReplicationCommandHandler:
            rows: a list of Stream.ROW_TYPE objects as returned by
                Stream.parse_row.
        """
-        logger.debug("Received rdata %s -> %s", stream_name, token)
+        logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
        await self._replication_data_handler.on_rdata(
            stream_name, instance_name, token, rows
        )
@@ -324,7 +342,7 @@ class ReplicationCommandHandler:
            self._pending_batches.pop(stream_name, [])
            # Find where we previously streamed up to.
-            current_token = stream.current_token()
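+            # (stream positions are now tracked per writer instance, hence the
+            # instance name argument)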
+            current_token = stream.current_token(cmd.instance_name)
            # If the position token matches our current token then we're up to
            # date and there's nothing to do. Otherwise, fetch all updates
@@ -361,7 +379,9 @@ class ReplicationCommandHandler:
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
- await self._replication_data_handler.on_position(stream_name, cmd.token)
+ await self._replication_data_handler.on_position(
+ cmd.stream_name, cmd.instance_name, cmd.token
+ )
        self._streams_by_connection.setdefault(conn, set()).add(stream_name)
@@ -489,12 +509,6 @@ class ReplicationCommandHandler:
        cmd = RemovePusherCommand(app_id, push_key, user_id)
        self.send_command(cmd)
-    def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
-        """Poke the master to invalidate a cache.
-        """
-        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
-        self.send_command(cmd)
-
    def send_user_ip(
        self,
        user_id: str,