diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3bbf3c3569..28826302f5 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,12 +16,17 @@
"""
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Tuple
from twisted.internet.protocol import ReconnectingClientFactory
-from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.api.constants import EventTypes
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.streams.events import (
+ EventsStream,
+ EventsStreamEventRow,
+ EventsStreamRow,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -83,8 +88,10 @@ class ReplicationDataHandler:
to handle updates in additional ways.
"""
- def __init__(self, store: BaseSlavedStore):
- self.store = store
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastore()
+ self.pusher_pool = hs.get_pusherpool()
+ self.notifier = hs.get_notifier()
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -100,10 +107,32 @@ class ReplicationDataHandler:
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
- self.store.process_replication_rows(stream_name, token, rows)
-
- async def on_position(self, stream_name: str, token: int):
- self.store.process_replication_rows(stream_name, token, [])
+ self.store.process_replication_rows(stream_name, instance_name, token, rows)
+
+ if stream_name == EventsStream.NAME:
+ # We shouldn't get multiple rows per token for events stream, so
+ # we don't need to optimise this for multiple rows.
+ for row in rows:
+ if row.type != EventsStreamEventRow.TypeId:
+ continue
+ assert isinstance(row, EventsStreamRow)
+
+ event = await self.store.get_event(
+ row.data.event_id, allow_rejected=True
+ )
+ if event.rejected_reason:
+ continue
+
+ extra_users = () # type: Tuple[str, ...]
+ if event.type == EventTypes.Member:
+ extra_users = (event.state_key,)
+ max_token = self.store.get_room_max_stream_ordering()
+ self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+ await self.pusher_pool.on_new_notifications(token, token)
+
+ async def on_position(self, stream_name: str, instance_name: str, token: int):
+ self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f58e384d17..c04f622816 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
return " ".join((self.app_id, self.push_key, self.user_id))
-class InvalidateCacheCommand(Command):
- """Sent by the client to invalidate an upstream cache.
-
- THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
- NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
- Mainly used to invalidate destination retry timing caches.
-
- Format::
-
- INVALIDATE_CACHE <cache_func> <keys_json>
-
- Where <keys_json> is a json list.
- """
-
- NAME = "INVALIDATE_CACHE"
-
- def __init__(self, cache_func, keys):
- self.cache_func = cache_func
- self.keys = keys
-
- @classmethod
- def from_line(cls, line):
- cache_func, keys_json = line.split(" ", 1)
-
- return cls(cache_func, json.loads(keys_json))
-
- def to_line(self):
- return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
@@ -439,7 +408,6 @@ _COMMANDS = (
UserSyncCommand,
FederationAckCommand,
RemovePusherCommand,
- InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
@@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
- InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 4328b38e9d..acfa66a7a8 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,18 +15,7 @@
# limitations under the License.
import logging
-from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Set,
- Tuple,
- TypeVar,
-)
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
from prometheus_client import Counter
@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
- InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@@ -48,7 +36,12 @@ from synapse.replication.tcp.commands import (
UserSyncCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
-from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.replication.tcp.streams import (
+ STREAMS_MAP,
+ CachesStream,
+ FederationStream,
+ Stream,
+)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -85,6 +78,26 @@ class ReplicationCommandHandler:
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
+ # List of streams that this instance is the source of
+ self._streams_to_replicate = [] # type: List[Stream]
+
+ for stream in self._streams.values():
+ if stream.NAME == CachesStream.NAME:
+ # All workers can write to the cache invalidation stream.
+ self._streams_to_replicate.append(stream)
+ continue
+
+ # Only add any other streams if we're on master.
+ if hs.config.worker_app is not None:
+ continue
+
+ if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+ # We only support federation stream if federation sending
+ # has been disabled on the master.
+ continue
+
+ self._streams_to_replicate.append(stream)
+
self._position_linearizer = Linearizer(
"replication_position", clock=self._clock
)
@@ -162,16 +175,33 @@ class ReplicationCommandHandler:
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
+ def get_streams(self) -> Dict[str, Stream]:
+ """Get a map from stream name to all streams.
+ """
+ return self._streams
+
+ def get_streams_to_replicate(self) -> List[Stream]:
+ """Get a list of streams that this instances replicates.
+ """
+ return self._streams_to_replicate
+
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
- # We only want to announce positions by the writer of the streams.
- # Currently this is just the master process.
- if not self._is_master:
- return
+ self.send_positions_to_connection(conn)
- for stream_name, stream in self._streams.items():
- current_token = stream.current_token()
+ def send_positions_to_connection(self, conn: AbstractConnection):
+ """Send current position of all streams this process is source of to
+ the connection.
+ """
+
+ # We respond with current position of all streams this instance
+ # replicates.
+ for stream in self.get_streams_to_replicate():
self.send_command(
- PositionCommand(stream_name, self._instance_name, current_token)
+ PositionCommand(
+ stream.NAME,
+ self._instance_name,
+ stream.current_token(self._instance_name),
+ )
)
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
@@ -208,18 +238,6 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
- async def on_INVALIDATE_CACHE(
- self, conn: AbstractConnection, cmd: InvalidateCacheCommand
- ):
- invalidate_cache_counter.inc()
-
- if self._is_master:
- # We invalidate the cache locally, but then also stream that to other
- # workers.
- await self._store.invalidate_cache_and_stream(
- cmd.cache_func, tuple(cmd.keys)
- )
-
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
@@ -293,7 +311,7 @@ class ReplicationCommandHandler:
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
- logger.debug("Received rdata %s -> %s", stream_name, token)
+ logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
@@ -324,7 +342,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to.
- current_token = stream.current_token()
+ current_token = stream.current_token(cmd.instance_name)
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
@@ -361,7 +379,9 @@ class ReplicationCommandHandler:
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
- await self._replication_data_handler.on_position(stream_name, cmd.token)
+ await self._replication_data_handler.on_position(
+ cmd.stream_name, cmd.instance_name, cmd.token
+ )
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
@@ -489,12 +509,6 @@ class ReplicationCommandHandler:
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
- def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
- """Poke the master to invalidate a cache.
- """
- cmd = InvalidateCacheCommand(cache_func.__name__, keys)
- self.send_command(cmd)
-
def send_user_ip(
self,
user_id: str,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 55bfa71dfd..e776b63183 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -70,7 +70,6 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.info("Connected to redis")
super().connectionMade()
run_as_background_process("subscribe-replication", self._send_subscribe)
- self.handler.new_connection(self)
async def _send_subscribe(self):
# it's important to make sure that we only send the REPLICATE command once we
@@ -81,9 +80,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
+ self.handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
+ # We send out our positions when there is a new connection in case the
+ # other side missed updates. We do this for Redis connections as the
+ # otherside won't know we've connected and so won't issue a REPLICATE.
+ self.handler.send_positions_to_connection(self)
+
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 33d2f589ac..41569305df 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,6 @@
import logging
import random
-from typing import Dict, List
from prometheus_client import Counter
@@ -25,7 +24,6 @@ from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -71,26 +69,11 @@ class ReplicationStreamer(object):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
+ self._instance_name = hs.get_instance_name()
self._replication_torture_level = hs.config.replication_torture_level
- # Work out list of streams that this instance is the source of.
- self.streams = [] # type: List[Stream]
- if hs.config.worker_app is None:
- for stream in STREAMS_MAP.values():
- if stream == FederationStream and hs.config.send_federation:
- # We only support federation stream if federation sending
- # hase been disabled on the master.
- continue
-
- self.streams.append(stream(hs))
-
- self.streams_by_name = {stream.NAME: stream for stream in self.streams}
-
- # Only bother registering the notifier callback if we have streams to
- # publish.
- if self.streams:
- self.notifier.add_replication_callback(self.on_notifier_poke)
+ self.notifier.add_replication_callback(self.on_notifier_poke)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@@ -98,10 +81,8 @@ class ReplicationStreamer(object):
self.command_handler = hs.get_tcp_replication()
- def get_streams(self) -> Dict[str, Stream]:
- """Get a mapp from stream name to stream instance.
- """
- return self.streams_by_name
+ # Set of streams to replicate.
+ self.streams = self.command_handler.get_streams_to_replicate()
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
@@ -145,7 +126,9 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
- if stream.last_token == stream.current_token():
+ if stream.last_token == stream.current_token(
+ self._instance_name
+ ):
continue
if self._replication_torture_level:
@@ -157,7 +140,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
- stream.current_token(),
+ stream.current_token(self._instance_name),
)
try:
updates, current_token, limited = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b0f87c365b..d42aaff055 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,14 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import heapq
import logging
from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+)
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
# the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
# just a row from a database query, though this is dependent on the stream in question.
#
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)
# The type returned by the update_function of a stream, as well as get_updates(),
# get_updates_since, etc.
@@ -95,19 +108,25 @@ class Stream(object):
def __init__(
self,
local_instance_name: str,
- current_token_function: Callable[[], Token],
+ current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
- current_token_function and update_function are callbacks which should be
- implemented by subclasses.
+ `current_token_function` and `update_function` are callbacks which
+ should be implemented by subclasses.
- current_token_function is called to get the current token of the underlying
- stream.
+ `current_token_function` takes an instance name, which is a writer to
+ the stream, and returns the position in the stream of the writer (as
+ viewed from the current process). On the writer process this is where
+ the writer has successfully written up to, whereas on other processes
+ this is the position which we have received updates up to over
+ replication. (Note that most streams have a single writer and so their
+ implementations ignore the instance name passed in).
- update_function is called to get updates for this stream between a pair of
- stream tokens. See the UpdateFunction type definition for more info.
+ `update_function` is called to get updates for this stream between a
+ pair of stream tokens. See the `UpdateFunction` type definition for more
+ info.
Args:
local_instance_name: The instance name of the current process
@@ -119,13 +138,13 @@ class Stream(object):
self.update_function = update_function
# The token from which we last asked for updates
- self.last_token = self.current_token()
+ self.last_token = self.current_token(self.local_instance_name)
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
- self.last_token = self.current_token()
+ self.last_token = self.current_token(self.local_instance_name)
async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
@@ -137,7 +156,7 @@ class Stream(object):
position in stream, and `limited` is whether there are more updates
to fetch.
"""
- current_token = self.current_token()
+ current_token = self.current_token(self.local_instance_name)
updates, current_token, limited = await self.get_updates_since(
self.local_instance_name, self.last_token, current_token
)
@@ -169,6 +188,16 @@ class Stream(object):
return updates, upto_token, limited
+def current_token_without_instance(
+ current_token: Callable[[], int]
+) -> Callable[[str], int]:
+ """Takes a current token callback function for a single writer stream
+ that doesn't take an instance name parameter and wraps it in a function that
+ does accept an instance name parameter but ignores it.
+ """
+ return lambda instance_name: current_token()
+
+
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
@@ -234,7 +263,7 @@ class BackfillStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_current_backfill_token,
+ current_token_without_instance(store.get_current_backfill_token),
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@@ -270,7 +299,9 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
- hs.get_instance_name(), store.get_current_presence_token, update_function
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_current_presence_token),
+ update_function,
)
@@ -295,7 +326,9 @@ class TypingStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
- hs.get_instance_name(), typing_handler.get_current_token, update_function
+ hs.get_instance_name(),
+ current_token_without_instance(typing_handler.get_current_token),
+ update_function,
)
@@ -318,7 +351,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_max_receipt_stream_id,
+ current_token_without_instance(store.get_max_receipt_stream_id),
db_query_to_update_function(store.get_all_updated_receipts),
)
@@ -338,7 +371,7 @@ class PushRulesStream(Stream):
hs.get_instance_name(), self._current_token, self._update_function
)
- def _current_token(self) -> int:
+ def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
@@ -372,7 +405,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
- store.get_pushers_stream_token,
+ current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@@ -401,13 +434,27 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
- store = hs.get_datastore()
+ self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_cache_stream_token,
- db_query_to_update_function(store.get_all_updated_caches),
+ self.store.get_cache_stream_token,
+ self._update_function,
)
+ async def _update_function(
+ self, instance_name: str, from_token: int, upto_token: int, limit: int
+ ):
+ rows = await self.store.get_all_updated_caches(
+ instance_name, from_token, upto_token, limit
+ )
+ updates = [(row[0], row[1:]) for row in rows]
+ limited = False
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
class PublicRoomsStream(Stream):
"""The public rooms list changed
@@ -430,7 +477,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_current_public_room_stream_id,
+ current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms),
)
@@ -451,7 +498,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_device_stream_token,
+ current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@@ -469,7 +516,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_to_device_stream_token,
+ current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages),
)
@@ -489,7 +536,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_max_account_data_stream_id,
+ current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags),
)
@@ -499,32 +546,63 @@ class AccountDataStream(Stream):
"""
AccountDataStreamRow = namedtuple(
- "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
+ "AccountDataStream",
+ ("user_id", "room_id", "data_type"), # str # Optional[str] # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- self.store.get_max_account_data_stream_id,
- db_query_to_update_function(self._update_function),
+ current_token_without_instance(self.store.get_max_account_data_stream_id),
+ self._update_function,
)
- async def _update_function(self, from_token, to_token, limit):
- global_results, room_results = await self.store.get_all_updated_account_data(
- from_token, from_token, to_token, limit
+ async def _update_function(
+ self, instance_name: str, from_token: int, to_token: int, limit: int
+ ) -> StreamUpdateResult:
+ limited = False
+ global_results = await self.store.get_updated_global_account_data(
+ from_token, to_token, limit
)
- results = list(room_results)
- results.extend(
- (stream_id, user_id, None, account_data_type)
+ # if the global results hit the limit, we'll need to limit the room results to
+ # the same stream token.
+ if len(global_results) >= limit:
+ to_token = global_results[-1][0]
+ limited = True
+
+ room_results = await self.store.get_updated_room_account_data(
+ from_token, to_token, limit
+ )
+
+ # likewise, if the room results hit the limit, limit the global results to
+ # the same stream token.
+ if len(room_results) >= limit:
+ to_token = room_results[-1][0]
+ limited = True
+
+ # convert the global results to the right format, and limit them to the to_token
+ # at the same time
+ global_rows = (
+ (stream_id, (user_id, None, account_data_type))
for stream_id, user_id, account_data_type in global_results
+ if stream_id <= to_token
+ )
+
+ # we know that the room_results are already limited to `to_token` so no need
+ # for a check on `stream_id` here.
+ room_rows = (
+ (stream_id, (user_id, room_id, account_data_type))
+ for stream_id, user_id, room_id, account_data_type in room_results
)
- return results
+ # we need to return a sorted list, so merge them together.
+ updates = list(heapq.merge(room_rows, global_rows))
+ return updates, to_token, limited
class GroupServerStream(Stream):
@@ -540,7 +618,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_group_stream_token,
+ current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes),
)
@@ -558,7 +636,7 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_device_stream_token,
+ current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 890e75d827..f370390331 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -20,7 +20,7 @@ from typing import List, Tuple, Type
import attr
-from ._base import Stream, StreamUpdateResult, Token
+from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream
@@ -119,7 +119,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- self._store.get_current_events_token,
+ current_token_without_instance(self._store.get_current_events_token),
self._update_function,
)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index e8bd52e389..9bcd13b009 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,11 @@
# limitations under the License.
from collections import namedtuple
-from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
+from synapse.replication.tcp.streams._base import (
+ Stream,
+ current_token_without_instance,
+ make_http_update_function,
+)
class FederationStream(Stream):
@@ -35,21 +39,35 @@ class FederationStream(Stream):
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
- # Not all synapse instances will have a federation sender instance,
- # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
- # so we stub the stream out when that is the case.
- if hs.config.worker_app is None or hs.should_send_federation():
+ if hs.config.worker_app is None:
+ # master process: get updates from the FederationRemoteSendQueue.
+ # (if the master is configured to send federation itself, federation_sender
+ # will be a real FederationSender, which has stubs for current_token and
+ # get_replication_rows.)
federation_sender = hs.get_federation_sender()
- current_token = federation_sender.get_current_token
- update_function = db_query_to_update_function(
- federation_sender.get_replication_rows
+ current_token = current_token_without_instance(
+ federation_sender.get_current_token
)
+ update_function = federation_sender.get_replication_rows
+
+ elif hs.should_send_federation():
+ # federation sender: Query master process
+ update_function = make_http_update_function(hs, self.NAME)
+ current_token = self._stub_current_token
+
else:
- current_token = lambda: 0
+ # other worker: stub out the update function (we're not interested in
+ # any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
+ current_token = self._stub_current_token
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
+ def _stub_current_token(instance_name: str) -> int:
+ # dummy current-token method for use on workers
+ return 0
+
+ @staticmethod
async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False
|