diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3bbf3c3569..20cb8a654f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -100,10 +100,10 @@ class ReplicationDataHandler:
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
- self.store.process_replication_rows(stream_name, token, rows)
+ self.store.process_replication_rows(stream_name, instance_name, token, rows)
- async def on_position(self, stream_name: str, token: int):
- self.store.process_replication_rows(stream_name, token, [])
+ async def on_position(self, stream_name: str, instance_name: str, token: int):
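+ # Passing an empty list of rows lets the store advance its position for
+ # this stream and writer without any updates to process.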
+ self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f58e384d17..c04f622816 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
return " ".join((self.app_id, self.push_key, self.user_id))
-class InvalidateCacheCommand(Command):
- """Sent by the client to invalidate an upstream cache.
-
- THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
- NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
- Mainly used to invalidate destination retry timing caches.
-
- Format::
-
- INVALIDATE_CACHE <cache_func> <keys_json>
-
- Where <keys_json> is a json list.
- """
-
- NAME = "INVALIDATE_CACHE"
-
- def __init__(self, cache_func, keys):
- self.cache_func = cache_func
- self.keys = keys
-
- @classmethod
- def from_line(cls, line):
- cache_func, keys_json = line.split(" ", 1)
-
- return cls(cache_func, json.loads(keys_json))
-
- def to_line(self):
- return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
@@ -439,7 +408,6 @@ _COMMANDS = (
UserSyncCommand,
FederationAckCommand,
RemovePusherCommand,
- InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
@@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
- InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b14a3d9fca..7c5d6c76e7 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,18 +15,7 @@
# limitations under the License.
import logging
-from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Set,
- Tuple,
- TypeVar,
-)
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
from prometheus_client import Counter
@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
- InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@@ -171,7 +159,7 @@ class ReplicationCommandHandler:
return
for stream_name, stream in self._streams.items():
- current_token = stream.current_token()
+ current_token = stream.current_token(self._instance_name)
self.send_command(
PositionCommand(stream_name, self._instance_name, current_token)
)
@@ -210,18 +198,6 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
- async def on_INVALIDATE_CACHE(
- self, conn: AbstractConnection, cmd: InvalidateCacheCommand
- ):
- invalidate_cache_counter.inc()
-
- if self._is_master:
- # We invalidate the cache locally, but then also stream that to other
- # workers.
- await self._store.invalidate_cache_and_stream(
- cmd.cache_func, tuple(cmd.keys)
- )
-
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
@@ -295,7 +271,7 @@ class ReplicationCommandHandler:
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
- logger.debug("Received rdata %s -> %s", stream_name, token)
+ logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
@@ -326,7 +302,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to.
- current_token = stream.current_token()
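+ # (Positions are now tracked per writer, so we look up the token for the
+ # instance that sent this POSITION command.)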
+ current_token = stream.current_token(cmd.instance_name)
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
@@ -363,7 +339,9 @@ class ReplicationCommandHandler:
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
- await self._replication_data_handler.on_position(stream_name, cmd.token)
+ await self._replication_data_handler.on_position(
+ cmd.stream_name, cmd.instance_name, cmd.token
+ )
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
@@ -491,12 +469,6 @@ class ReplicationCommandHandler:
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
- def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
- """Poke the master to invalidate a cache.
- """
- cmd = InvalidateCacheCommand(cache_func.__name__, keys)
- self.send_command(cmd)
-
def send_user_ip(
self,
user_id: str,
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index b690abedad..002171ce7c 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
+from synapse.replication.tcp.streams import (
+ STREAMS_MAP,
+ CachesStream,
+ FederationStream,
+ Stream,
+)
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -71,11 +76,16 @@ class ReplicationStreamer(object):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
+ self._instance_name = hs.get_instance_name()
self._replication_torture_level = hs.config.replication_torture_level
# Work out list of streams that this instance is the source of.
self.streams = [] # type: List[Stream]
+
+ # All workers can write to the cache invalidation stream.
+ self.streams.append(CachesStream(hs))
+
if hs.config.worker_app is None:
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
@@ -83,6 +93,10 @@ class ReplicationStreamer(object):
# has been disabled on the master.
continue
+ if stream == CachesStream:
+ # We've already added it above.
+ continue
+
self.streams.append(stream(hs))
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
@@ -145,7 +159,9 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
- if stream.last_token == stream.current_token():
+ if stream.last_token == stream.current_token(
+ self._instance_name
+ ):
continue
if self._replication_torture_level:
@@ -157,7 +173,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
- stream.current_token(),
+ stream.current_token(self._instance_name),
)
try:
updates, current_token, limited = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 084604e8b0..b48a6a3e91 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,20 +95,25 @@ class Stream(object):
def __init__(
self,
local_instance_name: str,
- current_token_function: Callable[[], Token],
+ current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
- current_token_function and update_function are callbacks which should be
- implemented by subclasses.
+ `current_token_function` and `update_function` are callbacks which
+ should be implemented by subclasses.
- current_token_function is called to get the current token of the underlying
- stream. It is only meaningful on the process that is the source of the
- replication stream (ie, usually the master).
+ `current_token_function` takes the name of an instance which writes to
+ the stream, and returns the position in the stream that that writer has
+ reached (as viewed from the current process). On the writer process this
+ is the point up to which the writer has successfully written, whereas on
+ other processes it is the position up to which we have received updates
+ over replication. (Note that most streams have a single writer, so their
+ implementations simply ignore the instance name passed in.)
- update_function is called to get updates for this stream between a pair of
- stream tokens. See the UpdateFunction type definition for more info.
+ `update_function` is called to get updates for this stream between a
+ pair of stream tokens. See the `UpdateFunction` type definition for more
+ info.
Args:
local_instance_name: The instance name of the current process
@@ -120,13 +125,13 @@ class Stream(object):
self.update_function = update_function
# The token from which we last asked for updates
- self.last_token = self.current_token()
+ self.last_token = self.current_token(self.local_instance_name)
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
- self.last_token = self.current_token()
+ self.last_token = self.current_token(self.local_instance_name)
async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
@@ -138,7 +143,7 @@ class Stream(object):
position in stream, and `limited` is whether there are more updates
to fetch.
"""
- current_token = self.current_token()
+ current_token = self.current_token(self.local_instance_name)
updates, current_token, limited = await self.get_updates_since(
self.local_instance_name, self.last_token, current_token
)
@@ -170,6 +175,16 @@ class Stream(object):
return updates, upto_token, limited
+def current_token_without_instance(
+ current_token: Callable[[], int]
+) -> Callable[[str], int]:
+ """Takes a current token callback function for a single writer stream
+ that doesn't take an instance name parameter and wraps it in a function that
+ does accept an instance name parameter but ignores it.
+ """
+ return lambda instance_name: current_token()
+
+
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
@@ -235,7 +250,7 @@ class BackfillStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_current_backfill_token,
+ current_token_without_instance(store.get_current_backfill_token),
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@@ -271,7 +286,9 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
- hs.get_instance_name(), store.get_current_presence_token, update_function
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_current_presence_token),
+ update_function,
)
@@ -296,7 +313,9 @@ class TypingStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
- hs.get_instance_name(), typing_handler.get_current_token, update_function
+ hs.get_instance_name(),
+ current_token_without_instance(typing_handler.get_current_token),
+ update_function,
)
@@ -319,7 +338,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_max_receipt_stream_id,
+ current_token_without_instance(store.get_max_receipt_stream_id),
db_query_to_update_function(store.get_all_updated_receipts),
)
@@ -339,7 +358,7 @@ class PushRulesStream(Stream):
hs.get_instance_name(), self._current_token, self._update_function
)
- def _current_token(self) -> int:
+ def _current_token(self, instance_name: str) -> int:
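+ # Push rules has a single writer, so the instance name is ignored.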
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
@@ -373,7 +392,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
- store.get_pushers_stream_token,
+ current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@@ -402,12 +421,26 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
- store = hs.get_datastore()
+ self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_cache_stream_token,
- db_query_to_update_function(store.get_all_updated_caches),
+ self.store.get_cache_stream_token,
+ self._update_function,
+ )
+
+ async def _update_function(
+ self, instance_name: str, from_token: int, upto_token: int, limit: int
+ ):
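+ # We need a bespoke update function here because get_all_updated_caches
+ # now takes the writer's instance name, which db_query_to_update_function
+ # does not pass through.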
+ rows = await self.store.get_all_updated_caches(
+ instance_name, from_token, upto_token, limit
)
+ updates = [(row[0], row[1:]) for row in rows]
+ limited = False
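+ # If we returned a full page of rows, report the token of the last row we
+ # actually returned and flag that there may be more updates to fetch.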
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
class PublicRoomsStream(Stream):
@@ -431,7 +464,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_current_public_room_stream_id,
+ current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms),
)
@@ -452,7 +485,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_device_stream_token,
+ current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@@ -470,7 +503,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_to_device_stream_token,
+ current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages),
)
@@ -490,7 +523,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_max_account_data_stream_id,
+ current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags),
)
@@ -510,7 +543,7 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- self.store.get_max_account_data_stream_id,
+ current_token_without_instance(self.store.get_max_account_data_stream_id),
db_query_to_update_function(self._update_function),
)
@@ -541,7 +574,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_group_stream_token,
+ current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes),
)
@@ -559,7 +592,7 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_device_stream_token,
+ current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 890e75d827..f370390331 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -20,7 +20,7 @@ from typing import List, Tuple, Type
import attr
-from ._base import Stream, StreamUpdateResult, Token
+from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream
@@ -119,7 +119,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- self._store.get_current_events_token,
+ current_token_without_instance(self._store.get_current_events_token),
self._update_function,
)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index b0505b8a2c..9bcd13b009 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,11 @@
# limitations under the License.
from collections import namedtuple
-from synapse.replication.tcp.streams._base import Stream, make_http_update_function
+from synapse.replication.tcp.streams._base import (
+ Stream,
+ current_token_without_instance,
+ make_http_update_function,
+)
class FederationStream(Stream):
@@ -41,7 +45,9 @@ class FederationStream(Stream):
# will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.)
federation_sender = hs.get_federation_sender()
- current_token = federation_sender.get_current_token
+ current_token = current_token_without_instance(
+ federation_sender.get_current_token
+ )
update_function = federation_sender.get_replication_rows
elif hs.should_send_federation():
@@ -58,7 +64,7 @@ class FederationStream(Stream):
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
- def _stub_current_token():
+ def _stub_current_token(instance_name: str) -> int:
# dummy current-token method for use on workers
return 0
|