diff options
author | Erik Johnston <erik@matrix.org> | 2020-05-07 13:51:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-07 13:51:08 +0100 |
commit | d7983b63a6746d92225295f1e9d521f847cf8ba7 (patch) | |
tree | 13d581210d94c26bd75036592996f6f53f7d4bb2 /synapse/replication/tcp | |
parent | Merge pull request #7398 from Starbix/alpine-3.11 (diff) | |
download | synapse-d7983b63a6746d92225295f1e9d521f847cf8ba7.tar.xz |
Support any process writing to cache invalidation stream. (#7436)
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r-- | synapse/replication/tcp/client.py | 6 | ||||
-rw-r--r-- | synapse/replication/tcp/commands.py | 33 | ||||
-rw-r--r-- | synapse/replication/tcp/handler.py | 42 | ||||
-rw-r--r-- | synapse/replication/tcp/resource.py | 22 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 87 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/events.py | 4 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/federation.py | 12 |
7 files changed, 100 insertions, 106 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 3bbf3c3569..20cb8a654f 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -100,10 +100,10 @@ class ReplicationDataHandler: token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - self.store.process_replication_rows(stream_name, token, rows) + self.store.process_replication_rows(stream_name, instance_name, token, rows) - async def on_position(self, stream_name: str, token: int): - self.store.process_replication_rows(stream_name, token, []) + async def on_position(self, stream_name: str, instance_name: str, token: int): + self.store.process_replication_rows(stream_name, instance_name, token, []) def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f58e384d17..c04f622816 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -341,37 +341,6 @@ class RemovePusherCommand(Command): return " ".join((self.app_id, self.push_key, self.user_id)) -class InvalidateCacheCommand(Command): - """Sent by the client to invalidate an upstream cache. - - THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE - NOT DISASTROUS IF WE DROP ON THE FLOOR. - - Mainly used to invalidate destination retry timing caches. - - Format:: - - INVALIDATE_CACHE <cache_func> <keys_json> - - Where <keys_json> is a json list. - """ - - NAME = "INVALIDATE_CACHE" - - def __init__(self, cache_func, keys): - self.cache_func = cache_func - self.keys = keys - - @classmethod - def from_line(cls, line): - cache_func, keys_json = line.split(" ", 1) - - return cls(cache_func, json.loads(keys_json)) - - def to_line(self): - return " ".join((self.cache_func, _json_encoder.encode(self.keys))) - - class UserIpCommand(Command): """Sent periodically when a worker sees activity from a client. @@ -439,7 +408,6 @@ _COMMANDS = ( UserSyncCommand, FederationAckCommand, RemovePusherCommand, - InvalidateCacheCommand, UserIpCommand, RemoteServerUpCommand, ClearUserSyncsCommand, @@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = ( ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, - InvalidateCacheCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, RemoteServerUpCommand.NAME, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b14a3d9fca..7c5d6c76e7 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -15,18 +15,7 @@ # limitations under the License. import logging -from typing import ( - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Set, - Tuple, - TypeVar, -) +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar from prometheus_client import Counter @@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, FederationAckCommand, - InvalidateCacheCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -171,7 +159,7 @@ class ReplicationCommandHandler: return for stream_name, stream in self._streams.items(): - current_token = stream.current_token() + current_token = stream.current_token(self._instance_name) self.send_command( PositionCommand(stream_name, self._instance_name, current_token) ) @@ -210,18 +198,6 @@ class ReplicationCommandHandler: self._notifier.on_new_replication_data() - async def on_INVALIDATE_CACHE( - self, conn: AbstractConnection, cmd: InvalidateCacheCommand - ): - invalidate_cache_counter.inc() - - if self._is_master: - # We invalidate the cache locally, but then also stream that to other - # workers. - await self._store.invalidate_cache_and_stream( - cmd.cache_func, tuple(cmd.keys) - ) - async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): user_ip_cache_counter.inc() @@ -295,7 +271,7 @@ class ReplicationCommandHandler: rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - logger.debug("Received rdata %s -> %s", stream_name, token) + logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token) await self._replication_data_handler.on_rdata( stream_name, instance_name, token, rows ) @@ -326,7 +302,7 @@ class ReplicationCommandHandler: self._pending_batches.pop(stream_name, []) # Find where we previously streamed up to. - current_token = stream.current_token() + current_token = stream.current_token(cmd.instance_name) # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates @@ -363,7 +339,9 @@ class ReplicationCommandHandler: logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position(stream_name, cmd.token) + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) self._streams_by_connection.setdefault(conn, set()).add(stream_name) @@ -491,12 +469,6 @@ class ReplicationCommandHandler: cmd = RemovePusherCommand(app_id, push_key, user_id) self.send_command(cmd) - def send_invalidate_cache(self, cache_func: Callable, keys: tuple): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - def send_user_ip( self, user_id: str, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index b690abedad..002171ce7c 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol -from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream +from synapse.replication.tcp.streams import ( + STREAMS_MAP, + CachesStream, + FederationStream, + Stream, +) from synapse.util.metrics import Measure stream_updates_counter = Counter( @@ -71,11 +76,16 @@ class ReplicationStreamer(object): self.store = hs.get_datastore() self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self._instance_name = hs.get_instance_name() self._replication_torture_level = hs.config.replication_torture_level # Work out list of streams that this instance is the source of. self.streams = [] # type: List[Stream] + + # All workers can write to the cache invalidation stream. + self.streams.append(CachesStream(hs)) + if hs.config.worker_app is None: for stream in STREAMS_MAP.values(): if stream == FederationStream and hs.config.send_federation: @@ -83,6 +93,10 @@ class ReplicationStreamer(object): # has been disabled on the master. continue + if stream == CachesStream: + # We've already added it above. + continue + self.streams.append(stream(hs)) self.streams_by_name = {stream.NAME: stream for stream in self.streams} @@ -145,7 +159,9 @@ class ReplicationStreamer(object): random.shuffle(all_streams) for stream in all_streams: - if stream.last_token == stream.current_token(): + if stream.last_token == stream.current_token( + self._instance_name + ): continue if self._replication_torture_level: @@ -157,7 +173,7 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, - stream.current_token(), + stream.current_token(self._instance_name), ) try: updates, current_token, limited = await stream.get_updates() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 084604e8b0..b48a6a3e91 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -95,20 +95,25 @@ class Stream(object): def __init__( self, local_instance_name: str, - current_token_function: Callable[[], Token], + current_token_function: Callable[[str], Token], update_function: UpdateFunction, ): """Instantiate a Stream - current_token_function and update_function are callbacks which should be - implemented by subclasses. + `current_token_function` and `update_function` are callbacks which + should be implemented by subclasses. - current_token_function is called to get the current token of the underlying - stream. It is only meaningful on the process that is the source of the - replication stream (ie, usually the master). + `current_token_function` takes an instance name, which is a writer to + the stream, and returns the position in the stream of the writer (as + viewed from the current process). On the writer process this is where + the writer has successfully written up to, whereas on other processes + this is the position which we have received updates up to over + replication. (Note that most streams have a single writer and so their + implementations ignore the instance name passed in). - update_function is called to get updates for this stream between a pair of - stream tokens. See the UpdateFunction type definition for more info. + `update_function` is called to get updates for this stream between a + pair of stream tokens. See the `UpdateFunction` type definition for more + info. Args: local_instance_name: The instance name of the current process @@ -120,13 +125,13 @@ class Stream(object): self.update_function = update_function # The token from which we last asked for updates - self.last_token = self.current_token() + self.last_token = self.current_token(self.local_instance_name) def discard_updates_and_advance(self): """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ - self.last_token = self.current_token() + self.last_token = self.current_token(self.local_instance_name) async def get_updates(self) -> StreamUpdateResult: """Gets all updates since the last time this function was called (or @@ -138,7 +143,7 @@ class Stream(object): position in stream, and `limited` is whether there are more updates to fetch. """ - current_token = self.current_token() + current_token = self.current_token(self.local_instance_name) updates, current_token, limited = await self.get_updates_since( self.local_instance_name, self.last_token, current_token ) @@ -170,6 +175,16 @@ class Stream(object): return updates, upto_token, limited +def current_token_without_instance( + current_token: Callable[[], int] +) -> Callable[[str], int]: + """Takes a current token callback function for a single writer stream + that doesn't take an instance name parameter and wraps it in a function that + does accept an instance name parameter but ignores it. + """ + return lambda instance_name: current_token() + + def db_query_to_update_function( query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: @@ -235,7 +250,7 @@ class BackfillStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_current_backfill_token, + current_token_without_instance(store.get_current_backfill_token), db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -271,7 +286,9 @@ class PresenceStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), store.get_current_presence_token, update_function + hs.get_instance_name(), + current_token_without_instance(store.get_current_presence_token), + update_function, ) @@ -296,7 +313,9 @@ class TypingStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), typing_handler.get_current_token, update_function + hs.get_instance_name(), + current_token_without_instance(typing_handler.get_current_token), + update_function, ) @@ -319,7 +338,7 @@ class ReceiptsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_max_receipt_stream_id, + current_token_without_instance(store.get_max_receipt_stream_id), db_query_to_update_function(store.get_all_updated_receipts), ) @@ -339,7 +358,7 @@ class PushRulesStream(Stream): hs.get_instance_name(), self._current_token, self._update_function ) - def _current_token(self) -> int: + def _current_token(self, instance_name: str) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token @@ -373,7 +392,7 @@ class PushersStream(Stream): super().__init__( hs.get_instance_name(), - store.get_pushers_stream_token, + current_token_without_instance(store.get_pushers_stream_token), db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -402,12 +421,26 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs): - store = hs.get_datastore() + self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_cache_stream_token, - db_query_to_update_function(store.get_all_updated_caches), + self.store.get_cache_stream_token, + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, upto_token: int, limit: int + ): + rows = await self.store.get_all_updated_caches( + instance_name, from_token, upto_token, limit ) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited class PublicRoomsStream(Stream): @@ -431,7 +464,7 @@ class PublicRoomsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_current_public_room_stream_id, + current_token_without_instance(store.get_current_public_room_stream_id), db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -452,7 +485,7 @@ class DeviceListsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_device_stream_token, + current_token_without_instance(store.get_device_stream_token), db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -470,7 +503,7 @@ class ToDeviceStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_to_device_stream_token, + current_token_without_instance(store.get_to_device_stream_token), db_query_to_update_function(store.get_all_new_device_messages), ) @@ -490,7 +523,7 @@ class TagAccountDataStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_max_account_data_stream_id, + current_token_without_instance(store.get_max_account_data_stream_id), db_query_to_update_function(store.get_all_updated_tags), ) @@ -510,7 +543,7 @@ class AccountDataStream(Stream): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self.store.get_max_account_data_stream_id, + current_token_without_instance(self.store.get_max_account_data_stream_id), db_query_to_update_function(self._update_function), ) @@ -541,7 +574,7 @@ class GroupServerStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_group_stream_token, + current_token_without_instance(store.get_group_stream_token), db_query_to_update_function(store.get_all_groups_changes), ) @@ -559,7 +592,7 @@ class UserSignatureStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_device_stream_token, + current_token_without_instance(store.get_device_stream_token), db_query_to_update_function( store.get_all_user_signature_changes_for_remotes ), diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 890e75d827..f370390331 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -20,7 +20,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token +from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance """Handling of the 'events' replication stream @@ -119,7 +119,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self._store.get_current_events_token, + current_token_without_instance(self._store.get_current_events_token), self._update_function, ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index b0505b8a2c..9bcd13b009 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,11 @@ # limitations under the License. from collections import namedtuple -from synapse.replication.tcp.streams._base import Stream, make_http_update_function +from synapse.replication.tcp.streams._base import ( + Stream, + current_token_without_instance, + make_http_update_function, +) class FederationStream(Stream): @@ -41,7 +45,9 @@ class FederationStream(Stream): # will be a real FederationSender, which has stubs for current_token and # get_replication_rows.) federation_sender = hs.get_federation_sender() - current_token = federation_sender.get_current_token + current_token = current_token_without_instance( + federation_sender.get_current_token + ) update_function = federation_sender.get_replication_rows elif hs.should_send_federation(): @@ -58,7 +64,7 @@ class FederationStream(Stream): super().__init__(hs.get_instance_name(), current_token, update_function) @staticmethod - def _stub_current_token(): + def _stub_current_token(instance_name: str) -> int: # dummy current-token method for use on workers return 0 |