From e8b68a4e4b439065536c281d8997af85880f6ee2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 Jan 2020 14:08:06 +0000 Subject: Fixup synapse.replication to pass mypy checks (#6667) --- synapse/replication/tcp/commands.py | 42 ++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 21 deletions(-) (limited to 'synapse/replication/tcp/commands.py') diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 0ff2a7199f..cbb36b9acf 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -20,15 +20,16 @@ allowed to be sent by which side. import logging import platform +from typing import Tuple, Type if platform.python_implementation() == "PyPy": import json _json_encoder = json.JSONEncoder() else: - import simplejson as json + import simplejson as json # type: ignore[no-redef] # noqa: F821 - _json_encoder = json.JSONEncoder(namedtuple_as_object=False) + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821 logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ class Command(object): The default implementation creates a command of form ` ` """ - NAME = None + NAME = None # type: str def __init__(self, data): self.data = data @@ -386,25 +387,24 @@ class UserIpCommand(Command): ) +_COMMANDS = ( + ServerCommand, + RdataCommand, + PositionCommand, + ErrorCommand, + PingCommand, + NameCommand, + ReplicateCommand, + UserSyncCommand, + FederationAckCommand, + SyncCommand, + RemovePusherCommand, + InvalidateCacheCommand, + UserIpCommand, +) # type: Tuple[Type[Command], ...] + # Map of command name to command type. -COMMAND_MAP = { - cmd.NAME: cmd - for cmd in ( - ServerCommand, - RdataCommand, - PositionCommand, - ErrorCommand, - PingCommand, - NameCommand, - ReplicateCommand, - UserSyncCommand, - FederationAckCommand, - SyncCommand, - RemovePusherCommand, - InvalidateCacheCommand, - UserIpCommand, - ) -} +COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS} # The commands the server is allowed to send VALID_SERVER_COMMANDS = ( -- cgit 1.5.1 From a8a50f5b5746279379b4511c8ecb2a40b143fe32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 Jan 2020 10:27:19 +0000 Subject: Wake up transaction queue when remote server comes back online (#6706) This will be used to retry outbound transactions to a remote server if we think it might have come back up. --- changelog.d/6706.misc | 1 + docs/tcp_replication.md | 6 +++++- synapse/app/federation_sender.py | 12 +++++++++++- synapse/federation/sender/__init__.py | 18 ++++++++++++++++-- synapse/federation/transport/server.py | 19 ++++++++++++++++++- synapse/notifier.py | 31 ++++++++++++++++++++++++++++--- synapse/replication/tcp/client.py | 3 +++ synapse/replication/tcp/commands.py | 17 +++++++++++++++++ synapse/replication/tcp/protocol.py | 15 +++++++++++++++ synapse/replication/tcp/resource.py | 9 +++++++++ synapse/server.pyi | 12 ++++++++++++ 11 files changed, 135 insertions(+), 8 deletions(-) create mode 100644 changelog.d/6706.misc (limited to 'synapse/replication/tcp/commands.py') diff --git a/changelog.d/6706.misc b/changelog.d/6706.misc new file mode 100644 index 0000000000..1ac11cc04b --- /dev/null +++ b/changelog.d/6706.misc @@ -0,0 +1 @@ +Attempt to retry sending a transaction when we detect a remote server has come back online, rather than waiting for a transaction to be triggered by new data. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index ba9e874d07..a0b1d563ff 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -209,7 +209,7 @@ Where `` may be either: * a numeric stream_id to stream updates since (exclusive) * `NOW` to stream all subsequent updates. -The `` is the name of a replication stream to subscribe +The `` is the name of a replication stream to subscribe to (see [here](../synapse/replication/tcp/streams/_base.py) for a list of streams). It can also be `ALL` to subscribe to all known streams, in which case the `` must be set to `NOW`. @@ -234,6 +234,10 @@ in which case the `` must be set to `NOW`. Used exclusively in tests +### REMOTE_SERVER_UP (S, C) + + Inform other processes that a remote server may have come back online. + See `synapse/replication/tcp/commands.py` for a detailed description and the format of each command. diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index a57cf991ac..38d11fdd0f 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -158,6 +158,13 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): args.update(self.send_handler.stream_positions()) return args + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + + # Let's wake up the transaction queue for the server in case we have + # pending stuff to send to it. + self.send_handler.wake_destination(server) + def start(config_options): try: @@ -205,7 +212,7 @@ class FederationSenderHandler(object): to the federation sender. """ - def __init__(self, hs, replication_client): + def __init__(self, hs: FederationSenderServer, replication_client): self.store = hs.get_datastore() self._is_mine_id = hs.is_mine_id self.federation_sender = hs.get_federation_sender() @@ -226,6 +233,9 @@ class FederationSenderHandler(object): self.store.get_room_max_stream_ordering() ) + def wake_destination(self, server: str): + self.federation_sender.wake_destination(server) + def stream_positions(self): return {"federation": self.federation_position} diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4ebb0e8bc0..36c83c3027 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -21,6 +21,7 @@ from prometheus_client import Counter from twisted.internet import defer +import synapse import synapse.metrics from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager @@ -54,7 +55,7 @@ sent_pdus_destination_dist_total = Counter( class FederationSender(object): - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs self.server_name = hs.hostname @@ -482,7 +483,20 @@ class FederationSender(object): def send_device_messages(self, destination): if destination == self.server_name: - logger.info("Not sending device update to ourselves") + logger.warning("Not sending device update to ourselves") + return + + self._get_per_destination_queue(destination).attempt_new_transaction() + + def wake_destination(self, destination: str): + """Called when we want to retry sending transactions to a remote. + + This is mainly useful if the remote server has been down and we think it + might have come back. + """ + + if destination == self.server_name: + logger.warning("Not waking up ourselves") return self._get_per_destination_queue(destination).attempt_new_transaction() diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index b4cbf23394..d8cf9ed299 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -44,6 +44,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string @@ -101,12 +102,17 @@ class NoAuthenticationError(AuthenticationError): class Authenticator(object): - def __init__(self, hs): + def __init__(self, hs: HomeServer): self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname self.store = hs.get_datastore() self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.notifer = hs.get_notifier() + + self.replication_client = None + if hs.config.worker.worker_app: + self.replication_client = hs.get_tcp_replication() # A method just so we can pass 'self' as the authenticator to the Servlets async def authenticate_request(self, request, content): @@ -166,6 +172,17 @@ class Authenticator(object): try: logger.info("Marking origin %r as up", origin) await self.store.set_destination_retry_timings(origin, None, 0, 0) + + # Inform the relevant places that the remote server is back up. + self.notifer.notify_remote_server_up(origin) + if self.replication_client: + # If we're on a worker we try and inform master about this. The + # replication client doesn't hook into the notifier to avoid + # infinite loops where we send a `REMOTE_SERVER_UP` command to + # master, which then echoes it back to us which in turn pokes + # the notifier. + self.replication_client.send_remote_server_up(origin) + except Exception: logger.exception("Error resetting retry timings on %s", origin) diff --git a/synapse/notifier.py b/synapse/notifier.py index 5f5f765bea..6132727cbd 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,11 +15,13 @@ import logging from collections import namedtuple +from typing import Callable, List from prometheus_client import Counter from twisted.internet import defer +import synapse.server from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state @@ -154,7 +156,7 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.user_to_user_stream = {} self.room_to_user_streams = {} @@ -164,7 +166,12 @@ class Notifier(object): self.store = hs.get_datastore() self.pending_new_room_events = [] - self.replication_callbacks = [] + # Called when there are new things to stream over replication + self.replication_callbacks = [] # type: List[Callable[[], None]] + + # Called when remote servers have come back online after having been + # down. + self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]] self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() @@ -205,7 +212,7 @@ class Notifier(object): "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream) ) - def add_replication_callback(self, cb): + def add_replication_callback(self, cb: Callable[[], None]): """Add a callback that will be called when some new data is available. Callback is not given any arguments. It should *not* return a Deferred - if it needs to do any asynchronous work, a background thread should be started and @@ -213,6 +220,12 @@ class Notifier(object): """ self.replication_callbacks.append(cb) + def add_remote_server_up_callback(self, cb: Callable[[str], None]): + """Add a callback that will be called when synapse detects a server + has been + """ + self.remote_server_up_callbacks.append(cb) + def on_new_room_event( self, event, room_stream_id, max_room_stream_id, extra_users=[] ): @@ -522,3 +535,15 @@ class Notifier(object): """Notify the any replication listeners that there's a new event""" for cb in self.replication_callbacks: cb() + + def notify_remote_server_up(self, server: str): + """Notify any replication that a remote server has come back up + """ + # We call federation_sender directly rather than registering as a + # callback as a) we already have a reference to it and b) it introduces + # circular dependencies. + if self.federation_sender: + self.federation_sender.wake_destination(server) + + for cb in self.remote_server_up_callbacks: + cb(server) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 52a0aefe68..fc06a7b053 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -143,6 +143,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): if d: d.callback(data) + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index cbb36b9acf..451671412d 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -387,6 +387,20 @@ class UserIpCommand(Command): ) +class RemoteServerUpCommand(Command): + """Sent when a worker has detected that a remote server is no longer + "down" and retry timings should be reset. + + If sent from a client the server will relay to all other workers. + + Format:: + + REMOTE_SERVER_UP + """ + + NAME = "REMOTE_SERVER_UP" + + _COMMANDS = ( ServerCommand, RdataCommand, @@ -401,6 +415,7 @@ _COMMANDS = ( RemovePusherCommand, InvalidateCacheCommand, UserIpCommand, + RemoteServerUpCommand, ) # type: Tuple[Type[Command], ...] # Map of command name to command type. @@ -414,6 +429,7 @@ VALID_SERVER_COMMANDS = ( ErrorCommand.NAME, PingCommand.NAME, SyncCommand.NAME, + RemoteServerUpCommand.NAME, ) # The commands the client is allowed to send @@ -427,4 +443,5 @@ VALID_CLIENT_COMMANDS = ( InvalidateCacheCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, + RemoteServerUpCommand.NAME, ) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 5f4bdf84d2..131e5acb09 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -76,6 +76,7 @@ from synapse.replication.tcp.commands import ( PingCommand, PositionCommand, RdataCommand, + RemoteServerUpCommand, ReplicateCommand, ServerCommand, SyncCommand, @@ -460,6 +461,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_INVALIDATE_CACHE(self, cmd): self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + self.streamer.on_remote_server_up(cmd.data) + async def on_USER_IP(self, cmd): self.streamer.on_user_ip( cmd.user_id, @@ -555,6 +559,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): def send_sync(self, data): self.send_command(SyncCommand(data)) + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) + def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) self.streamer.lost_connection(self) @@ -588,6 +595,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): """Called when get a new SYNC command.""" raise NotImplementedError() + @abc.abstractmethod + async def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + raise NotImplementedError() + @abc.abstractmethod def get_streams_to_replicate(self): """Called when a new connection has been established and we need to @@ -707,6 +719,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_SYNC(self, cmd): self.handler.on_sync(cmd.data) + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + self.handler.on_remote_server_up(cmd.data) + def replicate(self, stream_name, token): """Send the subscription request to the server """ diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index b1752e88cd..6ebf944f66 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -120,6 +120,7 @@ class ReplicationStreamer(object): self.federation_sender = hs.get_federation_sender() self.notifier.add_replication_callback(self.on_notifier_poke) + self.notifier.add_remote_server_up_callback(self.send_remote_server_up) # Keeps track of whether we are currently checking for updates self.is_looping = False @@ -288,6 +289,14 @@ class ReplicationStreamer(object): ) await self._server_notices_sender.on_user_ip(user_id) + @measure_func("repl.on_remote_server_up") + def on_remote_server_up(self, server: str): + self.notifier.notify_remote_server_up(server) + + def send_remote_server_up(self, server: str): + for conn in self.connections: + conn.send_remote_server_up(server) + def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. diff --git a/synapse/server.pyi b/synapse/server.pyi index b5e0b57095..0731403047 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,5 @@ +import twisted.internet + import synapse.api.auth import synapse.config.homeserver import synapse.federation.sender @@ -9,10 +11,12 @@ import synapse.handlers.deactivate_account import synapse.handlers.device import synapse.handlers.e2e_keys import synapse.handlers.message +import synapse.handlers.presence import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password import synapse.http.client +import synapse.notifier import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -85,3 +89,11 @@ class HomeServer(object): self, ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass + def get_notifier(self) -> synapse.notifier.Notifier: + pass + def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler: + pass + def get_clock(self) -> synapse.util.Clock: + pass + def get_reactor(self) -> twisted.internet.base.ReactorBase: + pass -- cgit 1.5.1 From 4cff617df1ba6f241fee6957cc44859f57edcc0e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Mar 2020 14:54:01 +0000 Subject: Move catchup of replication streams to worker. (#7024) This changes the replication protocol so that the server does not send down `RDATA` for rows that happened before the client connected. Instead, the server will send a `POSITION` and clients then query the database (or master out of band) to get up to date. --- changelog.d/7024.misc | 1 + docs/tcp_replication.md | 46 ++--- synapse/app/generic_worker.py | 3 + synapse/federation/sender/__init__.py | 9 + synapse/replication/http/__init__.py | 2 + synapse/replication/http/streams.py | 78 ++++++++ synapse/replication/slave/storage/_base.py | 14 +- synapse/replication/slave/storage/pushers.py | 3 + synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/commands.py | 34 +--- synapse/replication/tcp/protocol.py | 206 ++++++++-------------- synapse/replication/tcp/resource.py | 19 +- synapse/replication/tcp/streams/__init__.py | 8 +- synapse/replication/tcp/streams/_base.py | 160 +++++++++++------ synapse/replication/tcp/streams/events.py | 5 +- synapse/replication/tcp/streams/federation.py | 19 +- synapse/server.py | 5 + synapse/storage/data_stores/main/cache.py | 44 ++--- synapse/storage/data_stores/main/deviceinbox.py | 88 ++++----- synapse/storage/data_stores/main/events.py | 114 ------------ synapse/storage/data_stores/main/events_worker.py | 114 ++++++++++++ synapse/storage/data_stores/main/room.py | 40 ++--- tests/replication/tcp/streams/_base.py | 55 ++++-- tests/replication/tcp/streams/test_receipts.py | 52 +++++- 24 files changed, 635 insertions(+), 487 deletions(-) create mode 100644 changelog.d/7024.misc create mode 100644 synapse/replication/http/streams.py (limited to 'synapse/replication/tcp/commands.py') diff --git a/changelog.d/7024.misc b/changelog.d/7024.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7024.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index e3a4634b14..d4f7d9ec18 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and '<' worker to master flows): > SERVER example.com - < REPLICATE events 53 + < REPLICATE + > POSITION events 53 > RDATA events 54 ["$foo1:bar.com", ...] > RDATA events 55 ["$foo4:bar.com", ...] -The example shows the server accepting a new connection and sending its -identity with the `SERVER` command, followed by the client asking to -subscribe to the `events` stream from the token `53`. The server then -periodically sends `RDATA` commands which have the format -`RDATA `, where the format of `` is -defined by the individual streams. +The example shows the server accepting a new connection and sending its identity +with the `SERVER` command, followed by the client server to respond with the +position of all streams. The server then periodically sends `RDATA` commands +which have the format `RDATA `, where the format of +`` is defined by the individual streams. Error reporting happens by either the client or server sending an ERROR command, and usually the connection will be closed. @@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually connect to the server using a tool like netcat. A few things should be noted when manually using the protocol: -- When subscribing to a stream using `REPLICATE`, the special token - `NOW` can be used to get all future updates. The special stream name - `ALL` can be used with `NOW` to subscribe to all available streams. - The federation stream is only available if federation sending has been disabled on the main process. - The server will only time connections out that have sent a `PING` @@ -91,9 +88,7 @@ The client: - Sends a `NAME` command, allowing the server to associate a human friendly name with the connection. This is optional. - Sends a `PING` as above -- For each stream the client wishes to subscribe to it sends a - `REPLICATE` with the `stream_name` and token it wants to subscribe - from. +- Sends a `REPLICATE` to get the current position of all streams. - On receipt of a `SERVER` command, checks that the server name matches the expected server name. @@ -140,9 +135,7 @@ the wire: > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 + < REPLICATE > POSITION events 1 > POSITION backfill 1 > POSITION caches 1 @@ -181,9 +174,9 @@ client (C): #### POSITION (S) - The position of the stream has been updated. Sent to the client - after all missing updates for a stream have been sent to the client - and they're now up to date. + On receipt of a POSITION command clients should check if they have missed any + updates, and if so then fetch them out of band. Sent in response to a + REPLICATE command (but can happen at any time). #### ERROR (S, C) @@ -199,20 +192,7 @@ client (C): #### REPLICATE (C) -Asks the server to replicate a given stream. The syntax is: - -``` - REPLICATE -``` - -Where `` may be either: - * a numeric stream_id to stream updates since (exclusive) - * `NOW` to stream all subsequent updates. - -The `` is the name of a replication stream to subscribe -to (see [here](../synapse/replication/tcp/streams/_base.py) for a list -of streams). It can also be `ALL` to subscribe to all known streams, -in which case the `` must be set to `NOW`. +Asks the server for the current position of all streams. #### USER_SYNC (C) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index bd1733573b..fba7ad9551 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -401,6 +401,9 @@ class GenericWorkerTyping(object): self._room_serials[row.room_id] = token self._room_typing[row.room_id] = row.user_ids + def get_current_token(self) -> int: + return self._latest_room_serial + class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 233cb33daf..a477578e44 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -499,4 +499,13 @@ class FederationSender(object): self._get_per_destination_queue(destination).attempt_new_transaction() def get_current_token(self) -> int: + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. return 0 + + async def get_replication_rows( + self, from_token, to_token, limit, federation_ack=None + ): + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. + return [] diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 28dbc6fcba..4613b2538c 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -21,6 +21,7 @@ from synapse.replication.http import ( membership, register, send_event, + streams, ) REPLICATION_PREFIX = "/_synapse/replication" @@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource): login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) + streams.register_servlets(hs, self) diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py new file mode 100644 index 0000000000..ffd4c61993 --- /dev/null +++ b/synapse/replication/http/streams.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import SynapseError +from synapse.http.servlet import parse_integer +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationGetStreamUpdates(ReplicationEndpoint): + """Fetches stream updates from a server. Used for streams not persisted to + the database, e.g. typing notifications. + + The API looks like: + + GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100 + + 200 OK + + { + updates: [ ... ], + upto_token: 10, + limited: False, + } + + """ + + NAME = "get_repl_stream_updates" + PATH_ARGS = ("stream_name",) + METHOD = "GET" + + def __init__(self, hs): + super().__init__(hs) + + # We pull the streams from the replication steamer (if we try and make + # them ourselves we end up in an import loop). + self.streams = hs.get_replication_streamer().get_streams() + + @staticmethod + def _serialize_payload(stream_name, from_token, upto_token, limit): + return {"from_token": from_token, "upto_token": upto_token, "limit": limit} + + async def _handle_request(self, request, stream_name): + stream = self.streams.get(stream_name) + if stream is None: + raise SynapseError(400, "Unknown stream") + + from_token = parse_integer(request, "from_token", required=True) + upto_token = parse_integer(request, "upto_token", required=True) + limit = parse_integer(request, "limit", required=True) + + updates, upto_token, limited = await stream.get_updates_since( + from_token, upto_token, limit + ) + + return ( + 200, + {"updates": updates, "upto_token": upto_token, "limited": limited}, + ) + + +def register_servlets(hs, http_server): + ReplicationGetStreamUpdates(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index f45cbd37a0..751c799d94 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,8 +18,10 @@ from typing import Dict, Optional import six -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME +from synapse.storage.data_stores.main.cache import ( + CURRENT_STATE_CACHE_NAME, + CacheInvalidationWorkerStore, +) from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine @@ -35,7 +37,7 @@ def __func__(inp): return inp.__func__ -class BaseSlavedStore(SQLBaseStore): +class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: Database, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): @@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore): pos["caches"] = self._cache_id_gen.get_current_token() return pos + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 + def process_replication_rows(self, stream_name, token, rows): if stream_name == "caches": if self._cache_id_gen: diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index f22c2d44a3..bce8a3d115 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): result["pushers"] = self._pushers_id_gen.get_current_token() return result + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + def process_replication_rows(self, stream_name, token, rows): if stream_name == "pushers": self._pushers_id_gen.advance(token) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 02ab5b66ea..7e7ad0f798 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name + self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) @@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - self.client_name, self.server_name, self._clock, self.handler + self.hs, self.client_name, self.server_name, self._clock, self.handler, ) def clientConnectionLost(self, connector, reason): diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 451671412d..5a6b734094 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -136,8 +136,8 @@ class PositionCommand(Command): """Sent by the server to tell the client the stream postition without needing to send an RDATA. - Sent to the client after all missing updates for a stream have been sent - to the client and they're now up to date. + On receipt of a POSITION command clients should check if they have missed + any updates, and if so then fetch them out of band. """ NAME = "POSITION" @@ -179,42 +179,24 @@ class NameCommand(Command): class ReplicateCommand(Command): - """Sent by the client to subscribe to the stream. + """Sent by the client to subscribe to streams. Format:: - REPLICATE - - Where may be either: - * a numeric stream_id to stream updates from - * "NOW" to stream all subsequent updates. - - The can be "ALL" to subscribe to all known streams, in which - case the must be set to "NOW", i.e.:: - - REPLICATE ALL NOW + REPLICATE """ NAME = "REPLICATE" - def __init__(self, stream_name, token): - self.stream_name = stream_name - self.token = token + def __init__(self): + pass @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - if token in ("NOW", "now"): - token = "NOW" - else: - token = int(token) - return cls(stream_name, token) + return cls() def to_line(self): - return " ".join((self.stream_name, str(self.token))) - - def get_logcontext_id(self): - return "REPLICATE-" + self.stream_name + return "" class UserSyncCommand(Command): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index bc1482a9bb..f81d2e2442 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire:: > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 + < REPLICATE > POSITION events 1 > POSITION backfill 1 > POSITION caches 1 @@ -53,17 +51,15 @@ import fcntl import logging import struct from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set, Tuple +from typing import Any, DefaultDict, Dict, List, Set -from six import iteritems, iterkeys +from six import iteritems from prometheus_client import Counter -from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( @@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) -from synapse.replication.tcp.streams import STREAMS_MAP +from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string +MYPY = False +if MYPY: + from synapse.server import HomeServer + + connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.streamer = streamer - # The streams the client has subscribed to and is up to date with - self.replication_streams = set() # type: Set[str] - - # The streams the client is currently subscribing to. - self.connecting_streams = set() # type: Set[str] - - # Map from stream name to list of updates to send once we've finished - # subscribing the client to the stream. - self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] - def connectionMade(self): self.send_command(ServerCommand(self.server_name)) BaseReplicationStreamProtocol.connectionMade(self) @@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): ) async def on_REPLICATE(self, cmd): - stream_name = cmd.stream_name - token = cmd.token - - if stream_name == "ALL": - # Subscribe to all streams we're publishing to. - deferreds = [ - run_in_background(self.subscribe_to_stream, stream, token) - for stream in iterkeys(self.streamer.streams_by_name) - ] - - await make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - else: - await self.subscribe_to_stream(stream_name, token) + # Subscribe to all streams we're publishing to. + for stream_name in self.streamer.streams_by_name: + current_token = self.streamer.get_stream_token(stream_name) + self.send_command(PositionCommand(stream_name, current_token)) async def on_FEDERATION_ACK(self, cmd): self.streamer.federation_ack(cmd.token) @@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): cmd.last_seen, ) - async def subscribe_to_stream(self, stream_name, token): - """Subscribe the remote to a stream. - - This invloves checking if they've missed anything and sending those - updates down if they have. During that time new updates for the stream - are queued and sent once we've sent down any missed updates. - """ - self.replication_streams.discard(stream_name) - self.connecting_streams.add(stream_name) - - try: - # Get missing updates - updates, current_token = await self.streamer.get_stream_updates( - stream_name, token - ) - - # Send all the missing updates - for update in updates: - token, row = update[0], update[1] - self.send_command(RdataCommand(stream_name, token, row)) - - # We send a POSITION command to ensure that they have an up to - # date token (especially useful if we didn't send any updates - # above) - self.send_command(PositionCommand(stream_name, current_token)) - - # Now we can send any updates that came in while we were subscribing - pending_rdata = self.pending_rdata.pop(stream_name, []) - updates = [] - for token, update in pending_rdata: - # If the token is null, it is part of a batch update. Batches - # are multiple updates that share a single token. To denote - # this, the token is set to None for all tokens in the batch - # except for the last. If we find a None token, we keep looking - # through tokens until we find one that is not None and then - # process all previous updates in the batch as if they had the - # final token. - if token is None: - # Store this update as part of a batch - updates.append(update) - continue - - if token <= current_token: - # This update or batch of updates is older than - # current_token, dismiss it - updates = [] - continue - - updates.append(update) - - # Send all updates that are part of this batch with the - # found token - for update in updates: - self.send_command(RdataCommand(stream_name, token, update)) - - # Clear stored updates - updates = [] - - # They're now fully subscribed - self.replication_streams.add(stream_name) - except Exception as e: - logger.exception("[%s] Failed to handle REPLICATE command", self.id()) - self.send_error("failed to handle replicate: %r", e) - finally: - self.connecting_streams.discard(stream_name) - def stream_update(self, stream_name, token, data): """Called when a new update is available to stream to clients. We need to check if the client is interested in the stream or not """ - if stream_name in self.replication_streams: - # The client is subscribed to the stream - self.send_command(RdataCommand(stream_name, token, data)) - elif stream_name in self.connecting_streams: - # The client is being subscribed to the stream - logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) - self.pending_rdata.setdefault(stream_name, []).append((token, data)) - else: - # The client isn't subscribed - logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) + self.send_command(RdataCommand(stream_name, token, data)) def send_sync(self, data): self.send_command(SyncCommand(data)) @@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): def __init__( self, + hs: "HomeServer", client_name: str, server_name: str, clock: Clock, @@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.handler = handler + self.streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + # Set of stream names that have been subscribe to, but haven't yet # caught up with. This is used to track when the client has been fully # connected to the remote. - self.streams_connecting = set() # type: Set[str] + self.streams_connecting = set(STREAMS_MAP) # type: Set[str] # Map of stream to batched updates. See RdataCommand for info on how # batching works. - self.pending_batches = {} # type: Dict[str, Any] + self.pending_batches = {} # type: Dict[str, List[Any]] def connectionMade(self): self.send_command(NameCommand(self.client_name)) BaseReplicationStreamProtocol.connectionMade(self) # Once we've connected subscribe to the necessary streams - for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): - self.replicate(stream_name, token) + self.replicate() # Tell the server if we have any users currently syncing (should only # happen on synchrotrons) @@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) - # This will happen if we don't actually subscribe to any streams - if not self.streams_connecting: - self.handler.finished_connecting() - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): ) raise - if cmd.token is None: + if cmd.token is None or stream_name in self.streams_connecting: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token self.pending_batches.setdefault(stream_name, []).append(row) @@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): rows.append(row) await self.handler.on_rdata(stream_name, cmd.token, rows) - async def on_POSITION(self, cmd): - # When we get a `POSITION` command it means we've finished getting - # missing updates for the given stream, and are now up to date. + async def on_POSITION(self, cmd: PositionCommand): + stream = self.streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # Find where we previously streamed up to. + current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + + # Check if the connection was closed underneath us, if so we bail + # rather than risk having concurrent catch ups going on. + if self.state == ConnectionStates.CLOSED: + return + + if updates: + await self.handler.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self.handler.on_position(cmd.stream_name, cmd.token) + self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - await self.handler.on_position(cmd.stream_name, cmd.token) + # Check if the connection was closed underneath us, if so we bail + # rather than risk having concurrent catch ups going on. + if self.state == ConnectionStates.CLOSED: + return + + # Handle any RDATA that came in while we were catching up. + rows = self.pending_batches.pop(cmd.stream_name, []) + if rows: + await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) async def on_SYNC(self, cmd): self.handler.on_sync(cmd.data) @@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): self.handler.on_remote_server_up(cmd.data) - def replicate(self, stream_name, token): + def replicate(self): """Send the subscription request to the server """ - if stream_name not in STREAMS_MAP: - raise Exception("Invalid stream name %r" % (stream_name,)) - - logger.info( - "[%s] Subscribing to replication stream: %r from %r", - self.id(), - stream_name, - token, - ) - - self.streams_connecting.add(stream_name) + logger.info("[%s] Subscribing to replication streams", self.id()) - self.send_command(ReplicateCommand(stream_name, token)) + self.send_command(ReplicateCommand()) def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 6e2ebaf614..4374e99e32 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,7 +17,7 @@ import logging import random -from typing import Any, List +from typing import Any, Dict, List from six import itervalues @@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.metrics import Measure, measure_func from .protocol import ServerReplicationStreamProtocol -from .streams import STREAMS_MAP +from .streams import STREAMS_MAP, Stream from .streams.federation import FederationStream stream_updates_counter = Counter( @@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory): """ def __init__(self, hs): - self.streamer = ReplicationStreamer(hs) + self.streamer = hs.get_replication_streamer() self.clock = hs.get_clock() self.server_name = hs.config.server_name @@ -133,6 +133,11 @@ class ReplicationStreamer(object): for conn in self.connections: conn.send_error("server shutting down") + def get_streams(self) -> Dict[str, Stream]: + """Get a mapp from stream name to stream instance. + """ + return self.streams_by_name + def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the connections if there are. @@ -190,7 +195,8 @@ class ReplicationStreamer(object): stream.current_token(), ) try: - updates, current_token = await stream.get_updates() + updates, current_token, limited = await stream.get_updates() + self.pending_updates |= limited except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise @@ -226,8 +232,7 @@ class ReplicationStreamer(object): self.pending_updates = False self.is_looping = False - @measure_func("repl.get_stream_updates") - async def get_stream_updates(self, stream_name, token): + def get_stream_token(self, stream_name): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. """ @@ -235,7 +240,7 @@ class ReplicationStreamer(object): if not stream: raise Exception("unknown stream %s", stream_name) - return await stream.get_updates_since(token) + return stream.current_token() @measure_func("repl.federation_ack") def federation_ack(self, token): diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 29199f5b46..37bcd3de66 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -24,6 +24,9 @@ Each stream is defined by the following information: current_token: The function that returns the current token for the stream update_function: The function that returns a list of updates between two tokens """ + +from typing import Dict, Type + from synapse.replication.tcp.streams._base import ( AccountDataStream, BackfillStream, @@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import ( PushersStream, PushRulesStream, ReceiptsStream, + Stream, TagAccountDataStream, ToDeviceStream, TypingStream, @@ -63,10 +67,12 @@ STREAMS_MAP = { GroupServerStream, UserSignatureStream, ) -} +} # type: Dict[str, Type[Stream]] + __all__ = [ "STREAMS_MAP", + "Stream", "BackfillStream", "PresenceStream", "TypingStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 32d9514883..c14dff6c64 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from collections import namedtuple -from typing import Any, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr +from synapse.replication.http.streams import ReplicationGetStreamUpdates from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -29,6 +29,15 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 500000 +# Some type aliases to make things a bit easier. + +# A stream position token +Token = int + +# A pair of position in stream and args used to create an instance of `ROW_TYPE`. +StreamRow = Tuple[Token, tuple] + + class Stream(object): """Base class for the streams. @@ -56,6 +65,7 @@ class Stream(object): return cls.ROW_TYPE(*row) def __init__(self, hs): + # The token from which we last asked for updates self.last_token = self.current_token() @@ -65,61 +75,46 @@ class Stream(object): """ self.last_token = self.current_token() - async def get_updates(self): + async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before). Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - updates, current_token = await self.get_updates_since(self.last_token) + current_token = self.current_token() + updates, current_token, limited = await self.get_updates_since( + self.last_token, current_token + ) self.last_token = current_token - return updates, current_token + return updates, current_token, limited async def get_updates_since( - self, from_token: int - ) -> Tuple[List[Tuple[int, JsonDict]], int]: + self, from_token: Token, upto_token: Token, limit: int = 100 + ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: """Like get_updates except allows specifying from when we should stream updates Returns: - Resolves to a pair `(updates, new_last_token)`, where `updates` is - a list of `(token, row)` entries and `new_last_token` is the new - position in stream. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - if from_token in ("NOW", "now"): - return [], self.current_token() - - current_token = self.current_token() - from_token = int(from_token) - if from_token == current_token: - return [], current_token + if from_token == upto_token: + return [], upto_token, False - rows = await self.update_function( - from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 + updates, upto_token, limited = await self.update_function( + from_token, upto_token, limit=limit, ) - - # never turn more than MAX_EVENTS_BEHIND + 1 into updates. - rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) - - updates = [(row[0], row[1:]) for row in rows] - - # check we didn't get more rows than the limit. - # doing it like this allows the update_function to be a generator. - if len(updates) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behind" % (self.NAME)) - - # The update function didn't hit the limit, so we must have got all - # the updates to `current_token`, and can return that as our new - # stream position. - return updates, current_token + return updates, upto_token, limited def current_token(self): """Gets the current token of the underlying streams. Should be provided @@ -141,6 +136,48 @@ class Stream(object): raise NotImplementedError() +def db_query_to_update_function( + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] +) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]: + """Wraps a db query function which returns a list of rows to make it + suitable for use as an `update_function` for the Stream class + """ + + async def update_function(from_token, upto_token, limit): + rows = await query_function(from_token, upto_token, limit) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) == limit: + upto_token = rows[-1][0] + limited = True + + return updates, upto_token, limited + + return update_function + + +def make_http_update_function( + hs, stream_name: str +) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]: + """Makes a suitable function for use as an `update_function` that queries + the master process for updates. + """ + + client = ReplicationGetStreamUpdates.make_client(hs) + + async def update_function( + from_token: int, upto_token: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + return await client( + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, + limit=limit, + ) + + return update_function + + class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. @@ -164,7 +201,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_current_backfill_token # type: ignore - self.update_function = store.get_all_new_backfill_event_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore super(BackfillStream, self).__init__(hs) @@ -190,8 +227,15 @@ class PresenceStream(Stream): store = hs.get_datastore() presence_handler = hs.get_presence_handler() + self._is_worker = hs.config.worker_app is not None + self.current_token = store.get_current_presence_token # type: ignore - self.update_function = presence_handler.get_all_presence_updates # type: ignore + + if hs.config.worker_app is None: + self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore + else: + # Query master process + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(PresenceStream, self).__init__(hs) @@ -208,7 +252,12 @@ class TypingStream(Stream): typing_handler = hs.get_typing_handler() self.current_token = typing_handler.get_current_token # type: ignore - self.update_function = typing_handler.get_all_typing_updates # type: ignore + + if hs.config.worker_app is None: + self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore + else: + # Query master process + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(TypingStream, self).__init__(hs) @@ -232,7 +281,7 @@ class ReceiptsStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_receipt_stream_id # type: ignore - self.update_function = store.get_all_updated_receipts # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore super(ReceiptsStream, self).__init__(hs) @@ -256,7 +305,13 @@ class PushRulesStream(Stream): async def update_function(self, from_token, to_token, limit): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - return [(row[0], row[2]) for row in rows] + + limited = False + if len(rows) == limit: + to_token = rows[-1][0] + limited = True + + return [(row[0], (row[2],)) for row in rows], to_token, limited class PushersStream(Stream): @@ -275,7 +330,7 @@ class PushersStream(Stream): store = hs.get_datastore() self.current_token = store.get_pushers_stream_token # type: ignore - self.update_function = store.get_all_updated_pushers_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore super(PushersStream, self).__init__(hs) @@ -307,7 +362,7 @@ class CachesStream(Stream): store = hs.get_datastore() self.current_token = store.get_cache_stream_token # type: ignore - self.update_function = store.get_all_updated_caches # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore super(CachesStream, self).__init__(hs) @@ -333,7 +388,7 @@ class PublicRoomsStream(Stream): store = hs.get_datastore() self.current_token = store.get_current_public_room_stream_id # type: ignore - self.update_function = store.get_all_new_public_rooms # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore super(PublicRoomsStream, self).__init__(hs) @@ -354,7 +409,7 @@ class DeviceListsStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore super(DeviceListsStream, self).__init__(hs) @@ -372,7 +427,7 @@ class ToDeviceStream(Stream): store = hs.get_datastore() self.current_token = store.get_to_device_stream_token # type: ignore - self.update_function = store.get_all_new_device_messages # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore super(ToDeviceStream, self).__init__(hs) @@ -392,7 +447,7 @@ class TagAccountDataStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_account_data_stream_id # type: ignore - self.update_function = store.get_all_updated_tags # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore super(TagAccountDataStream, self).__init__(hs) @@ -412,10 +467,11 @@ class AccountDataStream(Stream): self.store = hs.get_datastore() self.current_token = self.store.get_max_account_data_stream_id # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(AccountDataStream, self).__init__(hs) - async def update_function(self, from_token, to_token, limit): + async def _update_function(self, from_token, to_token, limit): global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) @@ -442,7 +498,7 @@ class GroupServerStream(Stream): store = hs.get_datastore() self.current_token = store.get_group_stream_token # type: ignore - self.update_function = store.get_all_groups_changes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore super(GroupServerStream, self).__init__(hs) @@ -460,6 +516,6 @@ class UserSignatureStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index b3afabb8cd..c6a595629f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import Tuple, Type import attr -from ._base import Stream +from ._base import Stream, db_query_to_update_function """Handling of the 'events' replication stream @@ -117,10 +117,11 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() self.current_token = self._store.get_current_events_token # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(EventsStream, self).__init__(hs) - async def update_function(self, from_token, current_token, limit=None): + async def _update_function(self, from_token, current_token, limit=None): event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index f5f9336430..48c1d45718 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,9 @@ # limitations under the License. from collections import namedtuple -from ._base import Stream +from twisted.internet import defer + +from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function class FederationStream(Stream): @@ -33,11 +35,18 @@ class FederationStream(Stream): NAME = "federation" ROW_TYPE = FederationStreamRow + _QUERY_MASTER = True def __init__(self, hs): - federation_sender = hs.get_federation_sender() - - self.current_token = federation_sender.get_current_token # type: ignore - self.update_function = federation_sender.get_replication_rows # type: ignore + # Not all synapse instances will have a federation sender instance, + # whether that's a `FederationSender` or a `FederationRemoteSendQueue`, + # so we stub the stream out when that is the case. + if hs.config.worker_app is None or hs.should_send_federation(): + federation_sender = hs.get_federation_sender() + self.current_token = federation_sender.get_current_token # type: ignore + self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore + else: + self.current_token = lambda: 0 # type: ignore + self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore super(FederationStream, self).__init__(hs) diff --git a/synapse/server.py b/synapse/server.py index 1b980371de..9426eb1672 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -85,6 +85,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool +from synapse.replication.tcp.resource import ReplicationStreamer from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -199,6 +200,7 @@ class HomeServer(object): "saml_handler", "event_client_serializer", "storage", + "replication_streamer", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -536,6 +538,9 @@ class HomeServer(object): def build_storage(self) -> Storage: return Storage(self, self.datastores) + def build_replication_streamer(self) -> ReplicationStreamer: + return ReplicationStreamer(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index d4c44dcc75..4dc5da3fe8 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -32,7 +32,29 @@ logger = logging.getLogger(__name__) CURRENT_STATE_CACHE_NAME = "cs_cache_fake" -class CacheInvalidationStore(SQLBaseStore): +class CacheInvalidationWorkerStore(SQLBaseStore): + def get_all_updated_caches(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + +class CacheInvalidationStore(CacheInvalidationWorkerStore): async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore): }, ) - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_caches", get_all_updated_caches_txn - ) - def get_cache_stream_token(self): if self._cache_id_gen: return self._cache_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 0613b49f4a..9a1178fb39 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + def get_all_new_device_messages(self, last_pos, current_pos, limit): + """ + Args: + last_pos(int): + current_pos(int): + limit(int): + Returns: + A deferred list of rows from the device inbox + """ + if last_pos == current_pos: + return defer.succeed([]) + + def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_pos, last_pos + limit) + sql = ( + "SELECT max(stream_id), user_id" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY user_id" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows = txn.fetchall() + + sql = ( + "SELECT max(stream_id), destination" + " FROM device_federation_outbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY destination" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows.extend(txn) + + # Order by ascending stream ordering + rows.sort() + + return rows + + return self.db.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" @@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((user_id, device_id, stream_id, message_json)) txn.executemany(sql, rows) - - def get_all_new_device_messages(self, last_pos, current_pos, limit): - """ - Args: - last_pos(int): - current_pos(int): - limit(int): - Returns: - A deferred list of rows from the device inbox - """ - if last_pos == current_pos: - return defer.succeed([]) - - def get_all_new_device_messages_txn(txn): - # We limit like this as we might have multiple rows per stream_id, and - # we want to make sure we always get all entries for any stream_id - # we return. - upper_pos = min(current_pos, last_pos + limit) - sql = ( - "SELECT max(stream_id), user_id" - " FROM device_inbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY user_id" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows = txn.fetchall() - - sql = ( - "SELECT max(stream_id), destination" - " FROM device_federation_outbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY destination" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn) - - # Order by ascending stream ordering - rows.sort() - - return rows - - return self.db.runInteraction( - "get_all_new_device_messages", get_all_new_device_messages_txn - ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index d593ef47b8..e71c23541d 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1267,104 +1267,6 @@ class EventsStore( ret = yield self.db.runInteraction("count_daily_active_rooms", _count) return ret - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - - def get_current_events_token(self): - """The current maximum token that events have reached""" - return self._stream_id_gen.get_current_token() - - def get_all_new_forward_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_forward_event_rows(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_id, upper_bound)) - new_event_updates.extend(txn) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_forward_event_rows", get_all_new_forward_event_rows - ) - - def get_all_new_backfill_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_backfill_event_rows(txn): - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend(txn.fetchall()) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows - ) - @cached(num_args=5, max_entries=10) def get_all_new_events( self, @@ -1850,22 +1752,6 @@ class EventsStore( return (int(res["topological_ordering"]), int(res["stream_ordering"])) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): - def get_all_updated_current_state_deltas_txn(txn): - sql = """ - SELECT stream_id, room_id, type, state_key, event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC LIMIT ? - """ - txn.execute(sql, (from_token, to_token, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_current_state_deltas", - get_all_updated_current_state_deltas_txn, - ) - def insert_labels_for_event_txn( self, txn, event_id, labels, room_id, topological_ordering ): diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 3013f49d32..16ea8948b1 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore): complexity_v1 = round(state_events / 500, 2) return {"v1": complexity_v1} + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + return -self._backfill_id_gen.get_current_token() + + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_id, upper_bound)) + new_event_updates.extend(txn) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + + def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, + ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index e6c10c6316..aaebe427d3 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore): return total_media_quarantined + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = """ + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + + txn.execute(sql, (prev_id, current_id, limit)) + return txn.fetchall() + + if prev_id == current_id: + return defer.succeed([]) + + return self.db.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) + class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def get_all_new_public_rooms(self, prev_id, current_id, limit): - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (prev_id, current_id, limit)) - return txn.fetchall() - - if prev_id == current_id: - return defer.succeed([]) - - return self.db.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) - @defer.inlineCallbacks def block_room(self, room_id, user_id): """Marks the room as blocked. Can be called multiple times. diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index e96ad4ca4e..a755fe2879 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from mock import Mock from synapse.replication.tcp.commands import ReplicateCommand @@ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # build a replication server server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = server_factory.streamer - server = server_factory.buildProtocol(None) + self.server = server_factory.buildProtocol(None) - # build a replication client, with a dummy handler - handler_factory = Mock() - self.test_handler = TestReplicationClientHandler() - self.test_handler.factory = handler_factory + self.test_handler = Mock(wraps=TestReplicationClientHandler()) self.client = ClientReplicationStreamProtocol( - "client", "test", clock, self.test_handler + hs, "client", "test", clock, self.test_handler, ) - # wire them together - self.client.makeConnection(FakeTransport(server, reactor)) - server.makeConnection(FakeTransport(self.client, reactor)) + self._client_transport = None + self._server_transport = None + + def reconnect(self): + if self._client_transport: + self.client.close() + + if self._server_transport: + self.server.close() + + self._client_transport = FakeTransport(self.server, self.reactor) + self.client.makeConnection(self._client_transport) + + self._server_transport = FakeTransport(self.client, self.reactor) + self.server.makeConnection(self._server_transport) + + def disconnect(self): + if self._client_transport: + self._client_transport = None + self.client.close() + + if self._server_transport: + self._server_transport = None + self.server.close() def replicate(self): """Tell the master side of replication that something has happened, and then @@ -50,19 +69,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump(0.1) - def replicate_stream(self, stream, token="NOW"): + def replicate_stream(self): """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand(stream, token)) + self.client.send_command(ReplicateCommand()) class TestReplicationClientHandler(object): """Drop-in for ReplicationClientHandler which just collects RDATA rows""" def __init__(self): - self.received_rdata_rows = [] + self.streams = set() + self._received_rdata_rows = [] def get_streams_to_replicate(self): - return {} + positions = {s: 0 for s in self.streams} + for stream, token, _ in self._received_rdata_rows: + if stream in self.streams: + positions[stream] = max(token, positions.get(stream, 0)) + return positions def get_currently_syncing_users(self): return [] @@ -73,6 +97,9 @@ class TestReplicationClientHandler(object): def finished_connecting(self): pass + async def on_position(self, stream_name, token): + """Called when we get new position data.""" + async def on_rdata(self, stream_name, token, rows): for r in rows: - self.received_rdata_rows.append((stream_name, token, r)) + self._received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index fa2493cad6..0ec0825a0e 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -17,30 +17,64 @@ from synapse.replication.tcp.streams._base import ReceiptsStream from tests.replication.tcp.streams._base import BaseStreamTestCase USER_ID = "@feeling:blue" -ROOM_ID = "!room:blue" -EVENT_ID = "$event:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): def test_receipt(self): + self.reconnect() + # make the client subscribe to the receipts stream - self.replicate_stream("receipts", "NOW") + self.replicate_stream() + self.test_handler.streams.add("receipts") # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( - ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} + "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) self.replicate() # there should be one RDATA command - rdata_rows = self.test_handler.received_rdata_rows + self.test_handler.on_rdata.assert_called_once() + stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) - self.assertEqual(rdata_rows[0][0], "receipts") - row = rdata_rows[0][2] # type: ReceiptsStream.ReceiptsStreamRow - self.assertEqual(ROOM_ID, row.room_id) + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) - self.assertEqual(EVENT_ID, row.event_id) + self.assertEqual("$event:blue", row.event_id) self.assertEqual({"a": 1}, row.data) + + # Now let's disconnect and insert some data. + self.disconnect() + + self.test_handler.on_rdata.reset_mock() + + self.get_success( + self.hs.get_datastore().insert_receipt( + "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + ) + ) + self.replicate() + + # Nothing should have happened as we are disconnected + self.test_handler.on_rdata.assert_not_called() + + self.reconnect() + self.pump(0.1) + + # We should now have caught up and get the missing data + self.test_handler.on_rdata.assert_called_once() + stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") + self.assertEqual(token, 3) + self.assertEqual(1, len(rdata_rows)) + + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room2:blue", row.room_id) + self.assertEqual("m.read", row.receipt_type) + self.assertEqual(USER_ID, row.user_id) + self.assertEqual("$event2:foo", row.event_id) + self.assertEqual({"a": 2}, row.data) -- cgit 1.5.1 From 4f21c33be301b8ea6369039c3ad8baa51878e4d5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Mar 2020 16:37:24 +0100 Subject: Remove usage of "conn_id" for presence. (#7128) * Remove `conn_id` usage for UserSyncCommand. Each tcp replication connection is assigned a "conn_id", which is used to give an ID to a remotely connected worker. In a redis world, there will no longer be a one to one mapping between connection and instance, so instead we need to replace such usages with an ID generated by the remote instances and included in the replicaiton commands. This really only effects UserSyncCommand. * Add CLEAR_USER_SYNCS command that is sent on shutdown. This should help with the case where a synchrotron gets restarted gracefully, rather than rely on 5 minute timeout. --- changelog.d/7128.misc | 1 + docs/tcp_replication.md | 6 ++++++ synapse/app/generic_worker.py | 20 ++++++++++++++++---- synapse/replication/tcp/client.py | 6 ++++-- synapse/replication/tcp/commands.py | 36 ++++++++++++++++++++++++++++++++---- synapse/replication/tcp/protocol.py | 9 +++++++-- synapse/replication/tcp/resource.py | 17 +++++++---------- synapse/server.py | 11 +++++++++++ synapse/server.pyi | 2 ++ 9 files changed, 86 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7128.misc (limited to 'synapse/replication/tcp/commands.py') diff --git a/changelog.d/7128.misc b/changelog.d/7128.misc new file mode 100644 index 0000000000..5703f6d2ec --- /dev/null +++ b/changelog.d/7128.misc @@ -0,0 +1 @@ +Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index d4f7d9ec18..3be8e50c4c 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -198,6 +198,12 @@ Asks the server for the current position of all streams. A user has started or stopped syncing +#### CLEAR_USER_SYNC (C) + + The server should clear all associated user sync data from the worker. + + This is used when a worker is shutting down. + #### FEDERATION_ACK (C) Acknowledge receipt of some federation data diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index fba7ad9551..1ee266f7c5 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -65,6 +65,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import ( AccountDataStream, DeviceListsStream, @@ -124,7 +125,6 @@ from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole -from synapse.util.stringutils import random_string from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.generic_worker") @@ -233,6 +233,7 @@ class GenericWorkerPresence(object): self.user_to_num_current_syncs = {} self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self.instance_id = hs.get_instance_id() active_presence = self.store.take_presence_startup_info() self.user_to_current_state = {state.user_id: state for state in active_presence} @@ -245,13 +246,24 @@ class GenericWorkerPresence(object): self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) - self.process_id = random_string(16) - logger.info("Presence process_id is %r", self.process_id) + hs.get_reactor().addSystemEventTrigger( + "before", + "shutdown", + run_as_background_process, + "generic_presence.on_shutdown", + self._on_shutdown, + ) + + def _on_shutdown(self): + if self.hs.config.use_presence: + self.hs.get_tcp_replication().send_command( + ClearUserSyncsCommand(self.instance_id) + ) def send_user_sync(self, user_id, is_syncing, last_sync_ms): if self.hs.config.use_presence: self.hs.get_tcp_replication().send_user_sync( - user_id, is_syncing, last_sync_ms + self.instance_id, user_id, is_syncing, last_sync_ms ) def mark_as_coming_online(self, user_id): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 7e7ad0f798..e86d9805f1 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -189,10 +189,12 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): """ self.send_command(FederationAckCommand(token)) - def send_user_sync(self, user_id, is_syncing, last_sync_ms): + def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): """Poke the master that a user has started/stopped syncing. """ - self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms)) + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) def send_remove_pusher(self, app_id, push_key, user_id): """Poke the master to remove a pusher for a user diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 5a6b734094..e4eec643f7 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -207,30 +207,32 @@ class UserSyncCommand(Command): Format:: - USER_SYNC + USER_SYNC Where is either "start" or "stop" """ NAME = "USER_SYNC" - def __init__(self, user_id, is_syncing, last_sync_ms): + def __init__(self, instance_id, user_id, is_syncing, last_sync_ms): + self.instance_id = instance_id self.user_id = user_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod def from_line(cls, line): - user_id, state, last_sync_ms = line.split(" ", 2) + instance_id, user_id, state, last_sync_ms = line.split(" ", 3) if state not in ("start", "end"): raise Exception("Invalid USER_SYNC state %r" % (state,)) - return cls(user_id, state == "start", int(last_sync_ms)) + return cls(instance_id, user_id, state == "start", int(last_sync_ms)) def to_line(self): return " ".join( ( + self.instance_id, self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), @@ -238,6 +240,30 @@ class UserSyncCommand(Command): ) +class ClearUserSyncsCommand(Command): + """Sent by the client to inform the server that it should drop all + information about syncing users sent by the client. + + Mainly used when client is about to shut down. + + Format:: + + CLEAR_USER_SYNC + """ + + NAME = "CLEAR_USER_SYNC" + + def __init__(self, instance_id): + self.instance_id = instance_id + + @classmethod + def from_line(cls, line): + return cls(line) + + def to_line(self): + return self.instance_id + + class FederationAckCommand(Command): """Sent by the client when it has processed up to a given point in the federation stream. This allows the master to drop in-memory caches of the @@ -398,6 +424,7 @@ _COMMANDS = ( InvalidateCacheCommand, UserIpCommand, RemoteServerUpCommand, + ClearUserSyncsCommand, ) # type: Tuple[Type[Command], ...] # Map of command name to command type. @@ -420,6 +447,7 @@ VALID_CLIENT_COMMANDS = ( ReplicateCommand.NAME, PingCommand.NAME, UserSyncCommand.NAME, + ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, InvalidateCacheCommand.NAME, diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index f81d2e2442..dae246825f 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -423,9 +423,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_USER_SYNC(self, cmd): await self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) + async def on_CLEAR_USER_SYNC(self, cmd): + await self.streamer.on_clear_user_syncs(cmd.instance_id) + async def on_REPLICATE(self, cmd): # Subscribe to all streams we're publishing to. for stream_name in self.streamer.streams_by_name: @@ -551,6 +554,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): ): BaseReplicationStreamProtocol.__init__(self, clock) + self.instance_id = hs.get_instance_id() + self.client_name = client_name self.server_name = server_name self.handler = handler @@ -580,7 +585,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): currently_syncing = self.handler.get_currently_syncing_users() now = self.clock.time_msec() for user_id in currently_syncing: - self.send_command(UserSyncCommand(user_id, True, now)) + self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) # We've now finished connecting to so inform the client handler self.handler.update_connection(self) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 4374e99e32..8b6067e20d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -251,14 +251,19 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") - async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() await self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms + instance_id, user_id, is_syncing, last_sync_ms ) + async def on_clear_user_syncs(self, instance_id): + """A replication client wants us to drop all their UserSync data. + """ + await self.presence_handler.update_external_syncs_clear(instance_id) + @measure_func("repl.on_remove_pusher") async def on_remove_pusher(self, app_id, push_key, user_id): """A client has asked us to remove a pusher @@ -321,14 +326,6 @@ class ReplicationStreamer(object): except ValueError: pass - # We need to tell the presence handler that the connection has been - # lost so that it can handle any ongoing syncs on that connection. - run_as_background_process( - "update_external_syncs_clear", - self.presence_handler.update_external_syncs_clear, - connection.conn_id, - ) - def _batch_updates(updates): """Takes a list of updates of form [(token, row)] and sets the token to diff --git a/synapse/server.py b/synapse/server.py index c7ca2bda0d..cd86475d6b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -103,6 +103,7 @@ from synapse.storage import DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor +from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) @@ -230,6 +231,8 @@ class HomeServer(object): self._listening_services = [] self.start_time = None + self.instance_id = random_string(5) + self.clock = Clock(reactor) self.distributor = Distributor() self.ratelimiter = Ratelimiter() @@ -242,6 +245,14 @@ class HomeServer(object): for depname in kwargs: setattr(self, depname, kwargs[depname]) + def get_instance_id(self): + """A unique ID for this synapse process instance. + + This is used to distinguish running instances in worker-based + deployments. + """ + return self.instance_id + def setup(self): logger.info("Setting up.") self.start_time = int(self.get_clock().time()) diff --git a/synapse/server.pyi b/synapse/server.pyi index 3844f0e12f..9d1dfa71e7 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -114,3 +114,5 @@ class HomeServer(object): pass def is_mine_id(self, domain_id: str) -> bool: pass + def get_instance_id(self) -> str: + pass -- cgit 1.5.1 From 6a519a0ca06f42c35d5c88f3fa5f62cfd1553905 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 7 Apr 2020 16:59:40 +0100 Subject: Remove vestigal references to SYNC replication command We've ripped pretty much all of this out: let's remove the remains. --- synapse/replication/tcp/commands.py | 10 ---------- synapse/replication/tcp/handler.py | 4 ---- 2 files changed, 14 deletions(-) (limited to 'synapse/replication/tcp/commands.py') diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index e4eec643f7..a07a01278a 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -289,14 +289,6 @@ class FederationAckCommand(Command): return str(self.token) -class SyncCommand(Command): - """Used for testing. The client protocol implementation allows waiting - on a SYNC command with a specified data. - """ - - NAME = "SYNC" - - class RemovePusherCommand(Command): """Sent by the client to request the master remove the given pusher. @@ -419,7 +411,6 @@ _COMMANDS = ( ReplicateCommand, UserSyncCommand, FederationAckCommand, - SyncCommand, RemovePusherCommand, InvalidateCacheCommand, UserIpCommand, @@ -437,7 +428,6 @@ VALID_SERVER_COMMANDS = ( PositionCommand.NAME, ErrorCommand.NAME, PingCommand.NAME, - SyncCommand.NAME, RemoteServerUpCommand.NAME, ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index dd71d1bc34..2f5a299141 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -31,7 +31,6 @@ from synapse.replication.tcp.commands import ( RemoteServerUpCommand, RemovePusherCommand, ReplicateCommand, - SyncCommand, UserIpCommand, UserSyncCommand, ) @@ -281,9 +280,6 @@ class ReplicationCommandHandler: self._streams_connected.add(cmd.stream_name) - async def on_SYNC(self, cmd: SyncCommand): - pass - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) -- cgit 1.5.1 From c3e4b4edb270ad31321207125007ff41344510ed Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 7 Apr 2020 17:40:22 +0100 Subject: Fix warnings about not calling superclass constructor Separate `SimpleCommand` from `Command`, so that things which don't want to use the `data` property don't have to, and thus fix the warnings PyCharm was giving me about not calling `__init__` in the base class. --- synapse/replication/tcp/commands.py | 39 +++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 15 deletions(-) (limited to 'synapse/replication/tcp/commands.py') diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index a07a01278a..5ec89d0fb8 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -17,7 +17,7 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are allowed to be sent by which side. """ - +import abc import logging import platform from typing import Tuple, Type @@ -34,34 +34,29 @@ else: logger = logging.getLogger(__name__) -class Command(object): +class Command(metaclass=abc.ABCMeta): """The base command class. All subclasses must set the NAME variable which equates to the name of the command on the wire. A full command line on the wire is constructed from `NAME + " " + to_line()` - - The default implementation creates a command of form ` ` """ NAME = None # type: str - def __init__(self, data): - self.data = data - @classmethod + @abc.abstractmethod def from_line(cls, line): """Deserialises a line from the wire into this command. `line` does not include the command. """ - return cls(line) - def to_line(self): + @abc.abstractmethod + def to_line(self) -> str: """Serialises the comamnd for the wire. Does not include the command prefix. """ - return self.data def get_logcontext_id(self): """Get a suitable string for the logcontext when processing this command""" @@ -70,7 +65,21 @@ class Command(object): return self.NAME -class ServerCommand(Command): +class _SimpleCommand(Command): + """An implementation of Command whose argument is just a 'data' string.""" + + def __init__(self, data): + self.data = data + + @classmethod + def from_line(cls, line): + return cls(line) + + def to_line(self) -> str: + return self.data + + +class ServerCommand(_SimpleCommand): """Sent by the server on new connection and includes the server_name. Format:: @@ -155,7 +164,7 @@ class PositionCommand(Command): return " ".join((self.stream_name, str(self.token))) -class ErrorCommand(Command): +class ErrorCommand(_SimpleCommand): """Sent by either side if there was an ERROR. The data is a string describing the error. """ @@ -163,14 +172,14 @@ class ErrorCommand(Command): NAME = "ERROR" -class PingCommand(Command): +class PingCommand(_SimpleCommand): """Sent by either side as a keep alive. The data is arbitary (often timestamp) """ NAME = "PING" -class NameCommand(Command): +class NameCommand(_SimpleCommand): """Sent by client to inform the server of the client's identity. The data is the name """ @@ -387,7 +396,7 @@ class UserIpCommand(Command): ) -class RemoteServerUpCommand(Command): +class RemoteServerUpCommand(_SimpleCommand): """Sent when a worker has detected that a remote server is no longer "down" and retry timings should be reset. -- cgit 1.5.1 From 51f7eaf908a84fcaf231899e2bf1beae14ae72c0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 22 Apr 2020 13:07:41 +0100 Subject: Add ability to run replication protocol over redis. (#7040) This is configured via the `redis` config options. --- changelog.d/7040.feature | 1 + stubs/txredisapi.pyi | 40 +++++++ synapse/app/homeserver.py | 6 + synapse/config/homeserver.py | 2 + synapse/config/redis.py | 35 ++++++ synapse/python_dependencies.py | 1 + synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/commands.py | 18 +++ synapse/replication/tcp/handler.py | 50 +++++++-- synapse/replication/tcp/protocol.py | 38 ++----- synapse/replication/tcp/redis.py | 181 +++++++++++++++++++++++++++++++ tests/replication/slave/storage/_base.py | 4 +- 12 files changed, 342 insertions(+), 36 deletions(-) create mode 100644 changelog.d/7040.feature create mode 100644 stubs/txredisapi.pyi create mode 100644 synapse/config/redis.py create mode 100644 synapse/replication/tcp/redis.py (limited to 'synapse/replication/tcp/commands.py') diff --git a/changelog.d/7040.feature b/changelog.d/7040.feature new file mode 100644 index 0000000000..ce6140fdd1 --- /dev/null +++ b/changelog.d/7040.feature @@ -0,0 +1 @@ +Add support for running replication over Redis when using workers. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi new file mode 100644 index 0000000000..763d3fb404 --- /dev/null +++ b/stubs/txredisapi.pyi @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains *incomplete* type hints for txredisapi. +""" + +from typing import List, Optional, Union + +class RedisProtocol: + def publish(self, channel: str, message: bytes): ... + +class SubscriberProtocol: + def subscribe(self, channels: Union[str, List[str]]): ... + +def lazyConnection( + host: str = ..., + port: int = ..., + dbid: Optional[int] = ..., + reconnect: bool = ..., + charset: str = ..., + password: Optional[str] = ..., + connectTimeout: Optional[int] = ..., + replyTimeout: Optional[int] = ..., + convertNumbers: bool = ..., +) -> RedisProtocol: ... + +class SubscriberFactory: + def buildProtocol(self, addr): ... diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 49df63acd0..cbd1ea475a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -273,6 +273,12 @@ class SynapseHomeServer(HomeServer): def start_listening(self, listeners): config = self.get_config() + if config.redis_enabled: + # If redis is enabled we connect via the replication command handler + # in the same way as the workers (since we're effectively a client + # rather than a server). + self.get_tcp_replication().start_replication(self) + for listener in listeners: if listener["type"] == "http": self._listening_services.extend(self._listener_http(config, listener)) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index b4bca08b20..be6c6afa74 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -31,6 +31,7 @@ from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig from .ratelimiting import RatelimitConfig +from .redis import RedisConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig from .room_directory import RoomDirectoryConfig @@ -82,4 +83,5 @@ class HomeServerConfig(RootConfig): RoomDirectoryConfig, ThirdPartyRulesConfig, TracerConfig, + RedisConfig, ] diff --git a/synapse/config/redis.py b/synapse/config/redis.py new file mode 100644 index 0000000000..81a27619ec --- /dev/null +++ b/synapse/config/redis.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config +from synapse.python_dependencies import check_requirements + + +class RedisConfig(Config): + section = "redis" + + def read_config(self, config, **kwargs): + redis_config = config.get("redis", {}) + self.redis_enabled = redis_config.get("enabled", False) + + if not self.redis_enabled: + return + + check_requirements("redis") + + self.redis_host = redis_config.get("host", "localhost") + self.redis_port = redis_config.get("port", 6379) + self.redis_dbid = redis_config.get("dbid") + self.redis_password = redis_config.get("password") diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 8de8cb2c12..733c51b758 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -98,6 +98,7 @@ CONDITIONAL_REQUIREMENTS = { "sentry": ["sentry-sdk>=0.7.2"], "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"], "jwt": ["pyjwt>=1.6.4"], + "redis": ["txredisapi>=1.4.7"], } ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 700ae79158..2d07b8b2d0 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ReplicationClientFactory(ReconnectingClientFactory): +class DirectTcpReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 5ec89d0fb8..5002efe6a0 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -454,3 +454,21 @@ VALID_CLIENT_COMMANDS = ( ErrorCommand.NAME, RemoteServerUpCommand.NAME, ) + + +def parse_command_from_line(line: str) -> Command: + """Parses a command from a received line. + + Line should already be stripped of whitespace and be checked if blank. + """ + + idx = line.index(" ") + if idx >= 0: + cmd_name = line[:idx] + rest_of_line = line[idx + 1 :] + else: + cmd_name = line + rest_of_line = "" + + cmd_cls = COMMAND_MAP[cmd_name] + return cmd_cls.from_line(rest_of_line) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index e32e68e8c4..5b5ee2c13e 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -30,8 +30,10 @@ from typing import ( from prometheus_client import Counter +from twisted.internet.protocol import ReconnectingClientFactory + from synapse.metrics import LaterGauge -from synapse.replication.tcp.client import ReplicationClientFactory +from synapse.replication.tcp.client import DirectTcpReplicationClientFactory from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, @@ -92,7 +94,7 @@ class ReplicationCommandHandler: self._pending_batches = {} # type: Dict[str, List[Any]] # The factory used to create connections. - self._factory = None # type: Optional[ReplicationClientFactory] + self._factory = None # type: Optional[ReconnectingClientFactory] # The currently connected connections. self._connections = [] # type: List[AbstractConnection] @@ -119,11 +121,45 @@ class ReplicationCommandHandler: """Helper method to start a replication connection to the remote server using TCP. """ - client_name = hs.config.worker_name - self._factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self._factory) + if hs.config.redis.redis_enabled: + from synapse.replication.tcp.redis import ( + RedisDirectTcpReplicationClientFactory, + ) + import txredisapi + + logger.info( + "Connecting to redis (host=%r port=%r DBID=%r)", + hs.config.redis_host, + hs.config.redis_port, + hs.config.redis_dbid, + ) + + # We need two connections to redis, one for the subscription stream and + # one to send commands to (as you can't send further redis commands to a + # connection after SUBSCRIBE is called). + + # First create the connection for sending commands. + outbound_redis_connection = txredisapi.lazyConnection( + host=hs.config.redis_host, + port=hs.config.redis_port, + dbid=hs.config.redis_dbid, + password=hs.config.redis.redis_password, + reconnect=True, + ) + + # Now create the factory/connection for the subscription stream. + self._factory = RedisDirectTcpReplicationClientFactory( + hs, outbound_redis_connection + ) + hs.get_reactor().connectTCP( + hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, + ) + else: + client_name = hs.config.worker_name + self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self._factory) async def on_REPLICATE(self, cmd: ReplicateCommand): # We only want to announce positions by the writer of the streams. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 9276ed2965..7240acb0a2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -63,7 +63,6 @@ from twisted.python.failure import Failure from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( - COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, Command, @@ -72,6 +71,7 @@ from synapse.replication.tcp.commands import ( PingCommand, ReplicateCommand, ServerCommand, + parse_command_from_line, ) from synapse.types import Collection from synapse.util import Clock @@ -210,38 +210,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): linestr = line.decode("utf-8") - # split at the first " ", handling one-word commands - idx = linestr.index(" ") - if idx >= 0: - cmd_name = linestr[:idx] - rest_of_line = linestr[idx + 1 :] - else: - cmd_name = linestr - rest_of_line = "" + try: + cmd = parse_command_from_line(linestr) + except Exception as e: + logger.exception("[%s] failed to parse line: %r", self.id(), linestr) + self.send_error("failed to parse line: %r (%r):" % (e, linestr)) + return - if cmd_name not in self.VALID_INBOUND_COMMANDS: - logger.error("[%s] invalid command %s", self.id(), cmd_name) - self.send_error("invalid command: %s", cmd_name) + if cmd.NAME not in self.VALID_INBOUND_COMMANDS: + logger.error("[%s] invalid command %s", self.id(), cmd.NAME) + self.send_error("invalid command: %s", cmd.NAME) return self.last_received_command = self.clock.time_msec() - self.inbound_commands_counter[cmd_name] = ( - self.inbound_commands_counter[cmd_name] + 1 + self.inbound_commands_counter[cmd.NAME] = ( + self.inbound_commands_counter[cmd.NAME] + 1 ) - cmd_cls = COMMAND_MAP[cmd_name] - try: - cmd = cmd_cls.from_line(rest_of_line) - except Exception as e: - logger.exception( - "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line - ) - self.send_error( - "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line) - ) - return - # Now lets try and call on_ function run_as_background_process( "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py new file mode 100644 index 0000000000..4c08425735 --- /dev/null +++ b/synapse/replication/tcp/redis.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +import txredisapi + +from synapse.logging.context import PreserveLoggingContext +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.commands import ( + Command, + ReplicateCommand, + parse_command_from_line, +) +from synapse.replication.tcp.protocol import AbstractConnection + +if TYPE_CHECKING: + from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): + """Connection to redis subscribed to replication stream. + + Parses incoming messages from redis into replication commands, and passes + them to `ReplicationCommandHandler` + + Due to the vagaries of `txredisapi` we don't want to have a custom + constructor, so instead we expect the defined attributes below to be set + immediately after initialisation. + + Attributes: + handler: The command handler to handle incoming commands. + stream_name: The *redis* stream name to subscribe to (not anything to + do with Synapse replication streams). + outbound_redis_connection: The connection to redis to use to send + commands. + """ + + handler = None # type: ReplicationCommandHandler + stream_name = None # type: str + outbound_redis_connection = None # type: txredisapi.RedisProtocol + + def connectionMade(self): + logger.info("Connected to redis instance") + self.subscribe(self.stream_name) + self.send_command(ReplicateCommand()) + + self.handler.new_connection(self) + + def messageReceived(self, pattern: str, channel: str, message: str): + """Received a message from redis. + """ + + if message.strip() == "": + # Ignore blank lines + return + + try: + cmd = parse_command_from_line(message) + except Exception: + logger.exception( + "[%s] failed to parse line: %r", message, + ) + return + + # Now lets try and call on_ function + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd + ) + + async def handle_command(self, cmd: Command): + """Handle a command we have received over the replication stream. + + By default delegates to on_, which should return an awaitable. + + Args: + cmd: received command + """ + handled = False + + # First call any command handlers on this instance. These are for redis + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) + + def connectionLost(self, reason): + logger.info("Lost connection to redis instance") + self.handler.lost_connection(self) + + def send_command(self, cmd: Command): + """Send a command if connection has been established. + + Args: + cmd (Command) + """ + string = "%s %s" % (cmd.NAME, cmd.to_line()) + if "\n" in string: + raise Exception("Unexpected newline in command: %r", string) + + encoded_string = string.encode("utf-8") + + async def _send(): + with PreserveLoggingContext(): + # Note that we use the other connection as we can't send + # commands using the subscription connection. + await self.outbound_redis_connection.publish( + self.stream_name, encoded_string + ) + + run_as_background_process("send-cmd", _send) + + +class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): + """This is a reconnecting factory that connects to redis and immediately + subscribes to a stream. + + Args: + hs + outbound_redis_connection: A connection to redis that will be used to + send outbound commands (this is seperate to the redis connection + used to subscribe). + """ + + maxDelay = 5 + continueTrying = True + protocol = RedisSubscriber + + def __init__( + self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol + ): + + super().__init__() + + # This sets the password on the RedisFactory base class (as + # SubscriberFactory constructor doesn't pass it through). + self.password = hs.config.redis.redis_password + + self.handler = hs.get_tcp_replication() + self.stream_name = hs.hostname + + self.outbound_redis_connection = outbound_redis_connection + + def buildProtocol(self, addr): + p = super().buildProtocol(addr) # type: RedisSubscriber + + # We do this here rather than add to the constructor of `RedisSubcriber` + # as to do so would involve overriding `buildProtocol` entirely, however + # the base method does some other things than just instantiating the + # protocol. + p.handler = self.handler + p.outbound_redis_connection = self.outbound_redis_connection + p.stream_name = self.stream_name + + return p diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 8902a5ab69..395c7d0306 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -16,7 +16,7 @@ from mock import Mock, NonCallableMock from synapse.replication.tcp.client import ( - ReplicationClientFactory, + DirectTcpReplicationClientFactory, ReplicationDataHandler, ) from synapse.replication.tcp.handler import ReplicationCommandHandler @@ -61,7 +61,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.slaved_store ) - client_factory = ReplicationClientFactory( + client_factory = DirectTcpReplicationClientFactory( self.hs, "client_name", self.replication_handler ) client_factory.handler = self.replication_handler -- cgit 1.5.1 From 82d8b1dd1f9712e6fe29da671b079c646a663a48 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 22 Apr 2020 14:34:31 +0100 Subject: Another go at fixing one-word commands (#7326) I messed this up last time I tried (#7239 / e13c6c7). --- changelog.d/7326.misc | 1 + synapse/replication/tcp/commands.py | 2 +- tests/replication/tcp/test_commands.py | 42 ++++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7326.misc create mode 100644 tests/replication/tcp/test_commands.py (limited to 'synapse/replication/tcp/commands.py') diff --git a/changelog.d/7326.misc b/changelog.d/7326.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7326.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 5002efe6a0..f26aee83cb 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -462,7 +462,7 @@ def parse_command_from_line(line: str) -> Command: Line should already be stripped of whitespace and be checked if blank. """ - idx = line.index(" ") + idx = line.find(" ") if idx >= 0: cmd_name = line[:idx] rest_of_line = line[idx + 1 :] diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py new file mode 100644 index 0000000000..3cbcb513cc --- /dev/null +++ b/tests/replication/tcp/test_commands.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.replication.tcp.commands import ( + RdataCommand, + ReplicateCommand, + parse_command_from_line, +) + +from tests.unittest import TestCase + + +class ParseCommandTestCase(TestCase): + def test_parse_one_word_command(self): + line = "REPLICATE" + cmd = parse_command_from_line(line) + self.assertIsInstance(cmd, ReplicateCommand) + + def test_parse_rdata(self): + line = 'RDATA events 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' + cmd = parse_command_from_line(line) + self.assertIsInstance(cmd, RdataCommand) + self.assertEqual(cmd.stream_name, "events") + self.assertEqual(cmd.token, 6287863) + + def test_parse_rdata_batch(self): + line = 'RDATA presence batch ["@foo:example.com", "online"]' + cmd = parse_command_from_line(line) + self.assertIsInstance(cmd, RdataCommand) + self.assertEqual(cmd.stream_name, "presence") + self.assertIsNone(cmd.token) -- cgit 1.5.1 From 71a1abb8a116372556fd577ff1b85c7cfbe3c2b3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 22 Apr 2020 22:39:04 +0100 Subject: Stop the master relaying USER_SYNC for other workers (#7318) Long story short: if we're handling presence on the current worker, we shouldn't be sending USER_SYNC commands over replication. In an attempt to figure out what is going on here, I ended up refactoring some bits of the presencehandler code, so the first 4 commits here are non-functional refactors to move this code slightly closer to sanity. (There's still plenty to do here :/). Suggest reviewing individual commits. Fixes (I hope) #7257. --- changelog.d/7318.misc | 1 + docs/tcp_replication.md | 6 +- synapse/api/constants.py | 2 + synapse/app/generic_worker.py | 85 ++++++++------- synapse/handlers/events.py | 20 ++-- synapse/handlers/initial_sync.py | 10 +- synapse/handlers/presence.py | 210 ++++++++++++++++++++---------------- synapse/replication/tcp/commands.py | 7 +- synapse/replication/tcp/handler.py | 15 +-- synapse/server.pyi | 2 +- 10 files changed, 199 insertions(+), 159 deletions(-) create mode 100644 changelog.d/7318.misc (limited to 'synapse/replication/tcp/commands.py') diff --git a/changelog.d/7318.misc b/changelog.d/7318.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7318.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index 3be8e50c4c..b922d9cf7e 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -196,7 +196,7 @@ Asks the server for the current position of all streams. #### USER_SYNC (C) - A user has started or stopped syncing + A user has started or stopped syncing on this process. #### CLEAR_USER_SYNC (C) @@ -216,10 +216,6 @@ Asks the server for the current position of all streams. Inform the server a cache should be invalidated -#### SYNC (S, C) - - Used exclusively in tests - ### REMOTE_SERVER_UP (S, C) Inform other processes that a remote server may have come back online. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index fda2c2e5bb..bcaf2c3600 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -97,6 +97,8 @@ class EventTypes(object): Retention = "m.room.retention" + Presence = "m.presence" + class RejectedReason(object): AUTH_ERROR = "auth_error" diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 37afd2f810..2a56fe0bd5 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -17,6 +17,9 @@ import contextlib import logging import sys +from typing import Dict, Iterable + +from typing_extensions import ContextManager from twisted.internet import defer, reactor from twisted.web.resource import NoResource @@ -38,14 +41,14 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.federation import send_queue from synapse.federation.transport.server import TransportLayerServer -from synapse.handlers.presence import PresenceHandler, get_interested_parties +from synapse.handlers.presence import BasePresenceHandler, get_interested_parties from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ +from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore @@ -225,23 +228,32 @@ class KeyUploadServlet(RestServlet): return 200, {"one_time_key_counts": result} +class _NullContextManager(ContextManager[None]): + """A context manager which does nothing.""" + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + UPDATE_SYNCING_USERS_MS = 10 * 1000 -class GenericWorkerPresence(object): +class GenericWorkerPresence(BasePresenceHandler): def __init__(self, hs): + super().__init__(hs) self.hs = hs self.is_mine_id = hs.is_mine_id self.http_client = hs.get_simple_http_client() - self.store = hs.get_datastore() - self.user_to_num_current_syncs = {} - self.clock = hs.get_clock() + + self._presence_enabled = hs.config.use_presence + + # The number of ongoing syncs on this process, by user id. + # Empty if _presence_enabled is false. + self._user_to_num_current_syncs = {} # type: Dict[str, int] + self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() - active_presence = self.store.take_presence_startup_info() - self.user_to_current_state = {state.user_id: state for state in active_presence} - # user_id -> last_sync_ms. Lists the users that have stopped syncing # but we haven't notified the master of that yet self.users_going_offline = {} @@ -259,13 +271,13 @@ class GenericWorkerPresence(object): ) def _on_shutdown(self): - if self.hs.config.use_presence: + if self._presence_enabled: self.hs.get_tcp_replication().send_command( ClearUserSyncsCommand(self.instance_id) ) def send_user_sync(self, user_id, is_syncing, last_sync_ms): - if self.hs.config.use_presence: + if self._presence_enabled: self.hs.get_tcp_replication().send_user_sync( self.instance_id, user_id, is_syncing, last_sync_ms ) @@ -307,28 +319,33 @@ class GenericWorkerPresence(object): # TODO Hows this supposed to work? return defer.succeed(None) - get_states = __func__(PresenceHandler.get_states) - get_state = __func__(PresenceHandler.get_state) - current_state_for_users = __func__(PresenceHandler.current_state_for_users) + async def user_syncing( + self, user_id: str, affect_presence: bool + ) -> ContextManager[None]: + """Record that a user is syncing. + + Called by the sync and events servlets to record that a user has connected to + this worker and is waiting for some events. + """ + if not affect_presence or not self._presence_enabled: + return _NullContextManager() - def user_syncing(self, user_id, affect_presence): - if affect_presence: - curr_sync = self.user_to_num_current_syncs.get(user_id, 0) - self.user_to_num_current_syncs[user_id] = curr_sync + 1 + curr_sync = self._user_to_num_current_syncs.get(user_id, 0) + self._user_to_num_current_syncs[user_id] = curr_sync + 1 - # If we went from no in flight sync to some, notify replication - if self.user_to_num_current_syncs[user_id] == 1: - self.mark_as_coming_online(user_id) + # If we went from no in flight sync to some, notify replication + if self._user_to_num_current_syncs[user_id] == 1: + self.mark_as_coming_online(user_id) def _end(): # We check that the user_id is in user_to_num_current_syncs because # user_to_num_current_syncs may have been cleared if we are # shutting down. - if affect_presence and user_id in self.user_to_num_current_syncs: - self.user_to_num_current_syncs[user_id] -= 1 + if user_id in self._user_to_num_current_syncs: + self._user_to_num_current_syncs[user_id] -= 1 # If we went from one in flight sync to non, notify replication - if self.user_to_num_current_syncs[user_id] == 0: + if self._user_to_num_current_syncs[user_id] == 0: self.mark_as_going_offline(user_id) @contextlib.contextmanager @@ -338,7 +355,7 @@ class GenericWorkerPresence(object): finally: _end() - return defer.succeed(_user_syncing()) + return _user_syncing() @defer.inlineCallbacks def notify_from_replication(self, states, stream_id): @@ -373,15 +390,12 @@ class GenericWorkerPresence(object): stream_id = token yield self.notify_from_replication(states, stream_id) - def get_currently_syncing_users(self): - if self.hs.config.use_presence: - return [ - user_id - for user_id, count in self.user_to_num_current_syncs.items() - if count > 0 - ] - else: - return set() + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + return [ + user_id + for user_id, count in self._user_to_num_current_syncs.items() + if count > 0 + ] class GenericWorkerTyping(object): @@ -625,8 +639,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): self.store = hs.get_datastore() self.typing_handler = hs.get_typing_handler() - # NB this is a SynchrotronPresence, not a normal PresenceHandler - self.presence_handler = hs.get_presence_handler() + self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence self.notifier = hs.get_notifier() self.notify_pushers = hs.config.start_pushers diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index ec18a42a68..71a89f09c7 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -19,6 +19,7 @@ import random from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase +from synapse.handlers.presence import format_user_presence_state from synapse.logging.utils import log_function from synapse.types import UserID from synapse.visibility import filter_events_for_client @@ -97,6 +98,8 @@ class EventStreamHandler(BaseHandler): explicit_room_id=room_id, ) + time_now = self.clock.time_msec() + # When the user joins a new room, or another user joins a currently # joined room, we need to send down presence for those users. to_add = [] @@ -112,19 +115,20 @@ class EventStreamHandler(BaseHandler): users = await self.state.get_current_users_in_room( event.room_id ) - states = await presence_handler.get_states(users, as_event=True) - to_add.extend(states) else: + users = [event.state_key] - ev = await presence_handler.get_state( - UserID.from_string(event.state_key), as_event=True - ) - to_add.append(ev) + states = await presence_handler.get_states(users) + to_add.extend( + { + "type": EventTypes.Presence, + "content": format_user_presence_state(state, time_now), + } + for state in states + ) events.extend(to_add) - time_now = self.clock.time_msec() - chunks = await self._event_serializer.serialize_events( events, time_now, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index b116500c7d..f88bad5f25 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -381,10 +381,16 @@ class InitialSyncHandler(BaseHandler): return [] states = await presence_handler.get_states( - [m.user_id for m in room_members], as_event=True + [m.user_id for m in room_members] ) - return states + return [ + { + "type": EventTypes.Presence, + "content": format_user_presence_state(s, time_now), + } + for s in states + ] async def get_receipts(): receipts = await self.store.get_linearized_receipts_for_room( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6912165622..5cbefae177 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,10 +22,10 @@ The methods that define policy are: - PresenceHandler._handle_timeouts - should_notify """ - +import abc import logging from contextlib import contextmanager -from typing import Dict, List, Set +from typing import Dict, Iterable, List, Set from six import iteritems, itervalues @@ -41,7 +42,7 @@ from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState -from synapse.types import UserID, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure @@ -99,13 +100,106 @@ EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000 assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER -class PresenceHandler(object): +class BasePresenceHandler(abc.ABC): + """Parts of the PresenceHandler that are shared between workers and master""" + + def __init__(self, hs: "synapse.server.HomeServer"): + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + active_presence = self.store.take_presence_startup_info() + self.user_to_current_state = {state.user_id: state for state in active_presence} + + @abc.abstractmethod + async def user_syncing( + self, user_id: str, affect_presence: bool + ) -> ContextManager[None]: + """Returns a context manager that should surround any stream requests + from the user. + + This allows us to keep track of who is currently streaming and who isn't + without having to have timers outside of this module to avoid flickering + when users disconnect/reconnect. + + Args: + user_id: the user that is starting a sync + affect_presence: If false this function will be a no-op. + Useful for streams that are not associated with an actual + client that is being used by a user. + """ + + @abc.abstractmethod + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + """Get an iterable of syncing users on this worker, to send to the presence handler + + This is called when a replication connection is established. It should return + a list of user ids, which are then sent as USER_SYNC commands to inform the + process handling presence about those users. + + Returns: + An iterable of user_id strings. + """ + + async def get_state(self, target_user: UserID) -> UserPresenceState: + results = await self.get_states([target_user.to_string()]) + return results[0] + + async def get_states( + self, target_user_ids: Iterable[str] + ) -> List[UserPresenceState]: + """Get the presence state for users.""" + + updates_d = await self.current_state_for_users(target_user_ids) + updates = list(updates_d.values()) + + for user_id in set(target_user_ids) - {u.user_id for u in updates}: + updates.append(UserPresenceState.default(user_id)) + + return updates + + async def current_state_for_users( + self, user_ids: Iterable[str] + ) -> Dict[str, UserPresenceState]: + """Get the current presence state for multiple users. + + Returns: + dict: `user_id` -> `UserPresenceState` + """ + states = { + user_id: self.user_to_current_state.get(user_id, None) + for user_id in user_ids + } + + missing = [user_id for user_id, state in iteritems(states) if not state] + if missing: + # There are things not in our in memory cache. Lets pull them out of + # the database. + res = await self.store.get_presence_for_users(missing) + states.update(res) + + missing = [user_id for user_id, state in iteritems(states) if not state] + if missing: + new = { + user_id: UserPresenceState.default(user_id) for user_id in missing + } + states.update(new) + self.user_to_current_state.update(new) + + return states + + @abc.abstractmethod + async def set_state( + self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False + ) -> None: + """Set the presence state of the user. """ + + +class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "synapse.server.HomeServer"): + super().__init__(hs) self.hs = hs self.is_mine_id = hs.is_mine_id self.server_name = hs.hostname - self.clock = hs.get_clock() - self.store = hs.get_datastore() self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() self.federation = hs.get_federation_sender() @@ -115,13 +209,6 @@ class PresenceHandler(object): federation_registry.register_edu_handler("m.presence", self.incoming_presence) - active_presence = self.store.take_presence_startup_info() - - # A dictionary of the current state of users. This is prefilled with - # non-offline presence from the DB. We should fetch from the DB if - # we can't find a users presence in here. - self.user_to_current_state = {state.user_id: state for state in active_presence} - LaterGauge( "synapse_handlers_presence_user_to_current_state_size", "", @@ -130,7 +217,7 @@ class PresenceHandler(object): ) now = self.clock.time_msec() - for state in active_presence: + for state in self.user_to_current_state.values(): self.wheel_timer.insert( now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER ) @@ -361,10 +448,18 @@ class PresenceHandler(object): timers_fired_counter.inc(len(states)) + syncing_user_ids = { + user_id + for user_id, count in self.user_to_num_current_syncs.items() + if count + } + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) + changes = handle_timeouts( states, is_mine_fn=self.is_mine_id, - syncing_user_ids=self.get_currently_syncing_users(), + syncing_user_ids=syncing_user_ids, now=now, ) @@ -462,22 +557,9 @@ class PresenceHandler(object): return _user_syncing() - def get_currently_syncing_users(self): - """Get the set of user ids that are currently syncing on this HS. - Returns: - set(str): A set of user_id strings. - """ - if self.hs.config.use_presence: - syncing_user_ids = { - user_id - for user_id, count in self.user_to_num_current_syncs.items() - if count - } - for user_ids in self.external_process_to_current_syncs.values(): - syncing_user_ids.update(user_ids) - return syncing_user_ids - else: - return set() + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + # since we are the process handling presence, there is nothing to do here. + return [] async def update_external_syncs_row( self, process_id, user_id, is_syncing, sync_time_msec @@ -554,34 +636,6 @@ class PresenceHandler(object): res = await self.current_state_for_users([user_id]) return res[user_id] - async def current_state_for_users(self, user_ids): - """Get the current presence state for multiple users. - - Returns: - dict: `user_id` -> `UserPresenceState` - """ - states = { - user_id: self.user_to_current_state.get(user_id, None) - for user_id in user_ids - } - - missing = [user_id for user_id, state in iteritems(states) if not state] - if missing: - # There are things not in our in memory cache. Lets pull them out of - # the database. - res = await self.store.get_presence_for_users(missing) - states.update(res) - - missing = [user_id for user_id, state in iteritems(states) if not state] - if missing: - new = { - user_id: UserPresenceState.default(user_id) for user_id in missing - } - states.update(new) - self.user_to_current_state.update(new) - - return states - async def _persist_and_notify(self, states): """Persist states in the database, poke the notifier and send to interested remote servers @@ -669,40 +723,6 @@ class PresenceHandler(object): federation_presence_counter.inc(len(updates)) await self._update_states(updates) - async def get_state(self, target_user, as_event=False): - results = await self.get_states([target_user.to_string()], as_event=as_event) - - return results[0] - - async def get_states(self, target_user_ids, as_event=False): - """Get the presence state for users. - - Args: - target_user_ids (list) - as_event (bool): Whether to format it as a client event or not. - - Returns: - list - """ - - updates = await self.current_state_for_users(target_user_ids) - updates = list(updates.values()) - - for user_id in set(target_user_ids) - {u.user_id for u in updates}: - updates.append(UserPresenceState.default(user_id)) - - now = self.clock.time_msec() - if as_event: - return [ - { - "type": "m.presence", - "content": format_user_presence_state(state, now), - } - for state in updates - ] - else: - return updates - async def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ @@ -889,7 +909,7 @@ class PresenceHandler(object): user_ids = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, user_ids)) - states = await self.current_state_for_users(user_ids) + states_d = await self.current_state_for_users(user_ids) # Filter out old presence, i.e. offline presence states where # the user hasn't been active for a week. We can change this @@ -899,7 +919,7 @@ class PresenceHandler(object): now = self.clock.time_msec() states = [ state - for state in states.values() + for state in states_d.values() if state.state != PresenceState.OFFLINE or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 or state.status_msg is not None diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f26aee83cb..c7880d4b63 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -210,7 +210,10 @@ class ReplicateCommand(Command): class UserSyncCommand(Command): """Sent by the client to inform the server that a user has started or - stopped syncing. Used to calculate presence on the master. + stopped syncing on this process. + + This is used by the process handling presence (typically the master) to + calculate who is online and who is not. Includes a timestamp of when the last user sync was. @@ -218,7 +221,7 @@ class UserSyncCommand(Command): USER_SYNC - Where is either "start" or "stop" + Where is either "start" or "end" """ NAME = "USER_SYNC" diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 5b5ee2c13e..0db5a3a24d 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -337,13 +337,6 @@ class ReplicationCommandHandler: if self._is_master: self._notifier.notify_remote_server_up(cmd.data) - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users. - """ - return self._presence_handler.get_currently_syncing_users() - def new_connection(self, connection: AbstractConnection): """Called when we have a new connection. """ @@ -361,9 +354,11 @@ class ReplicationCommandHandler: if self._factory: self._factory.resetDelay() - # Tell the server if we have any users currently syncing (should only - # happen on synchrotrons) - currently_syncing = self.get_currently_syncing_users() + # Tell the other end if we have any users currently syncing. + currently_syncing = ( + self._presence_handler.get_currently_syncing_users_for_replication() + ) + now = self._clock.time_msec() for user_id in currently_syncing: connection.send_command( diff --git a/synapse/server.pyi b/synapse/server.pyi index 9013e9bac9..f1a5717028 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -97,7 +97,7 @@ class HomeServer(object): pass def get_notifier(self) -> synapse.notifier.Notifier: pass - def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler: + def get_presence_handler(self) -> synapse.handlers.presence.BasePresenceHandler: pass def get_clock(self) -> synapse.util.Clock: pass -- cgit 1.5.1 From 37f6823f5b91f27b9dd8de8fc0e52d5ea889647c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 Apr 2020 16:23:08 +0100 Subject: Add instance name to RDATA/POSITION commands (#7364) This is primarily for allowing us to send those commands from workers, but for now simply allows us to ignore echoed RDATA/POSITION commands that we sent (we get echoes of sent commands when using redis). Currently we log a WARNING on the master process every time we receive an echoed RDATA. --- changelog.d/7364.misc | 1 + docs/tcp_replication.md | 41 +++++++++++++++++++------------- synapse/app/_base.py | 4 ++-- synapse/logging/opentracing.py | 23 ++++++++---------- synapse/replication/tcp/commands.py | 37 +++++++++++++++++++--------- synapse/replication/tcp/handler.py | 17 ++++++++++--- synapse/server.py | 13 ++++++++-- synapse/server.pyi | 2 ++ tests/replication/slave/storage/_base.py | 1 + tests/replication/tcp/test_commands.py | 6 +++-- 10 files changed, 95 insertions(+), 50 deletions(-) create mode 100644 changelog.d/7364.misc (limited to 'synapse/replication/tcp/commands.py') diff --git a/changelog.d/7364.misc b/changelog.d/7364.misc new file mode 100644 index 0000000000..bb5d727cf4 --- /dev/null +++ b/changelog.d/7364.misc @@ -0,0 +1 @@ +Add an `instance_name` to `RDATA` and `POSITION` replication commands. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index b922d9cf7e..ab2fffbfe4 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -15,15 +15,17 @@ example flow would be (where '>' indicates master to worker and > SERVER example.com < REPLICATE - > POSITION events 53 - > RDATA events 54 ["$foo1:bar.com", ...] - > RDATA events 55 ["$foo4:bar.com", ...] + > POSITION events master 53 + > RDATA events master 54 ["$foo1:bar.com", ...] + > RDATA events master 55 ["$foo4:bar.com", ...] The example shows the server accepting a new connection and sending its identity with the `SERVER` command, followed by the client server to respond with the position of all streams. The server then periodically sends `RDATA` commands -which have the format `RDATA `, where the format of -`` is defined by the individual streams. +which have the format `RDATA `, where +the format of `` is defined by the individual streams. The +`` is the name of the Synapse process that generated the data +(usually "master"). Error reporting happens by either the client or server sending an ERROR command, and usually the connection will be closed. @@ -52,7 +54,7 @@ The basic structure of the protocol is line based, where the initial word of each line specifies the command. The rest of the line is parsed based on the command. For example, the RDATA command is defined as: - RDATA + RDATA (Note that may contains spaces, but cannot contain newlines.) @@ -136,11 +138,11 @@ the wire: < NAME synapse.app.appservice < PING 1490197665618 < REPLICATE - > POSITION events 1 - > POSITION backfill 1 - > POSITION caches 1 - > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] - > RDATA events 14 ["$149019767112vOHxz:localhost:8823", + > POSITION events master 1 + > POSITION backfill master 1 + > POSITION caches master 1 + > RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] + > RDATA events master 14 ["$149019767112vOHxz:localhost:8823", "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] < PING 1490197675618 > ERROR server stopping @@ -151,10 +153,10 @@ position without needing to send data with the `RDATA` command. An example of a batched set of `RDATA` is: - > RDATA caches batch ["get_user_by_id",["@test:localhost:8823"],1490197670513] - > RDATA caches batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513] - > RDATA caches batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513] - > RDATA caches 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513] + > RDATA caches master 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513] In this case the client shouldn't advance their caches token until it sees the the last `RDATA`. @@ -178,6 +180,11 @@ client (C): updates, and if so then fetch them out of band. Sent in response to a REPLICATE command (but can happen at any time). + The POSITION command includes the source of the stream. Currently all streams + are written by a single process (usually "master"). If fetching missing + updates via HTTP API, rather than via the DB, then processes should make the + request to the appropriate process. + #### ERROR (S, C) There was an error @@ -234,12 +241,12 @@ Each individual cache invalidation results in a row being sent down replication, which includes the cache name (the name of the function) and they key to invalidate. For example: - > RDATA caches 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251] + > RDATA caches master 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251] Alternatively, an entire cache can be invalidated by sending down a `null` instead of the key. For example: - > RDATA caches 550953772 ["get_user_by_id", null, 1550574873252] + > RDATA caches master 550953772 ["get_user_by_id", null, 1550574873252] However, there are times when a number of caches need to be invalidated at the same time with the same key. To reduce traffic we batch those diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 4d84f4595a..628292b890 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -270,7 +270,7 @@ def start(hs, listeners=None): # Start the tracer synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa - hs.config + hs ) # It is now safe to start your Synapse. @@ -316,7 +316,7 @@ def setup_sentry(hs): scope.set_tag("matrix_server_name", hs.config.server_name) app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver" - name = hs.config.worker_name if hs.config.worker_name else "master" + name = hs.get_instance_name() scope.set_tag("worker_app", app) scope.set_tag("worker_name", name) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 0638cec429..5dddf57008 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -171,7 +171,7 @@ import logging import re import types from functools import wraps -from typing import Dict +from typing import TYPE_CHECKING, Dict from canonicaljson import json @@ -179,6 +179,9 @@ from twisted.internet import defer from synapse.config import ConfigError +if TYPE_CHECKING: + from synapse.server import HomeServer + # Helper class @@ -297,14 +300,11 @@ def _noop_context_manager(*args, **kwargs): # Setup -def init_tracer(config): +def init_tracer(hs: "HomeServer"): """Set the whitelists and initialise the JaegerClient tracer - - Args: - config (HomeserverConfig): The config used by the homeserver """ global opentracing - if not config.opentracer_enabled: + if not hs.config.opentracer_enabled: # We don't have a tracer opentracing = None return @@ -315,18 +315,15 @@ def init_tracer(config): "installed." ) - # Include the worker name - name = config.worker_name if config.worker_name else "master" - # Pull out the jaeger config if it was given. Otherwise set it to something sensible. # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py - set_homeserver_whitelist(config.opentracer_whitelist) + set_homeserver_whitelist(hs.config.opentracer_whitelist) JaegerConfig( - config=config.jaeger_config, - service_name="{} {}".format(config.server_name, name), - scope_manager=LogContextScopeManager(config), + config=hs.config.jaeger_config, + service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()), + scope_manager=LogContextScopeManager(hs.config), ).initialize_tracer() diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index c7880d4b63..f58e384d17 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -95,7 +95,7 @@ class RdataCommand(Command): Format:: - RDATA + RDATA The `` may either be a numeric stream id OR "batch". The latter case is used to support sending multiple updates with the same stream ID. This @@ -105,33 +105,40 @@ class RdataCommand(Command): The client should batch all incoming RDATA with a token of "batch" (per stream_name) until it sees an RDATA with a numeric stream ID. + The `` is the source of the new data (usually "master"). + `` of "batch" maps to the instance variable `token` being None. An example of a batched series of RDATA:: - RDATA presence batch ["@foo:example.com", "online", ...] - RDATA presence batch ["@bar:example.com", "online", ...] - RDATA presence 59 ["@baz:example.com", "online", ...] + RDATA presence master batch ["@foo:example.com", "online", ...] + RDATA presence master batch ["@bar:example.com", "online", ...] + RDATA presence master 59 ["@baz:example.com", "online", ...] """ NAME = "RDATA" - def __init__(self, stream_name, token, row): + def __init__(self, stream_name, instance_name, token, row): self.stream_name = stream_name + self.instance_name = instance_name self.token = token self.row = row @classmethod def from_line(cls, line): - stream_name, token, row_json = line.split(" ", 2) + stream_name, instance_name, token, row_json = line.split(" ", 3) return cls( - stream_name, None if token == "batch" else int(token), json.loads(row_json) + stream_name, + instance_name, + None if token == "batch" else int(token), + json.loads(row_json), ) def to_line(self): return " ".join( ( self.stream_name, + self.instance_name, str(self.token) if self.token is not None else "batch", _json_encoder.encode(self.row), ) @@ -145,23 +152,31 @@ class PositionCommand(Command): """Sent by the server to tell the client the stream postition without needing to send an RDATA. + Format:: + + POSITION + On receipt of a POSITION command clients should check if they have missed any updates, and if so then fetch them out of band. + + The `` is the process that sent the command and is the source + of the stream. """ NAME = "POSITION" - def __init__(self, stream_name, token): + def __init__(self, stream_name, instance_name, token): self.stream_name = stream_name + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - return cls(stream_name, int(token)) + stream_name, instance_name, token = line.split(" ", 2) + return cls(stream_name, instance_name, int(token)) def to_line(self): - return " ".join((self.stream_name, str(self.token))) + return " ".join((self.stream_name, self.instance_name, str(self.token))) class ErrorCommand(_SimpleCommand): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b8f49a8d0f..6f7054d5af 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -79,6 +79,7 @@ class ReplicationCommandHandler: self._notifier = hs.get_notifier() self._clock = hs.get_clock() self._instance_id = hs.get_instance_id() + self._instance_name = hs.get_instance_name() # Set of streams that we've caught up with. self._streams_connected = set() # type: Set[str] @@ -156,7 +157,7 @@ class ReplicationCommandHandler: hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, ) else: - client_name = hs.config.worker_name + client_name = hs.get_instance_name() self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port @@ -170,7 +171,9 @@ class ReplicationCommandHandler: for stream_name, stream in self._streams.items(): current_token = stream.current_token() - self.send_command(PositionCommand(stream_name, current_token)) + self.send_command( + PositionCommand(stream_name, self._instance_name, current_token) + ) async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): user_sync_counter.inc() @@ -235,6 +238,10 @@ class ReplicationCommandHandler: await self._server_notices_sender.on_user_ip(cmd.user_id) async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + if cmd.instance_name == self._instance_name: + # Ignore RDATA that are just our own echoes + return + stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -286,6 +293,10 @@ class ReplicationCommandHandler: await self._replication_data_handler.on_rdata(stream_name, token, rows) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + if cmd.instance_name == self._instance_name: + # Ignore POSITION that are just our own echoes + return + stream = self._streams.get(cmd.stream_name) if not stream: logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) @@ -485,7 +496,7 @@ class ReplicationCommandHandler: We need to check if the client is interested in the stream or not """ - self.send_command(RdataCommand(stream_name, token, data)) + self.send_command(RdataCommand(stream_name, self._instance_name, token, data)) UpdateToken = TypeVar("UpdateToken") diff --git a/synapse/server.py b/synapse/server.py index 9d273c980c..bf97a16c09 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -234,7 +234,8 @@ class HomeServer(object): self._listening_services = [] self.start_time = None - self.instance_id = random_string(5) + self._instance_id = random_string(5) + self._instance_name = config.worker_name or "master" self.clock = Clock(reactor) self.distributor = Distributor() @@ -254,7 +255,15 @@ class HomeServer(object): This is used to distinguish running instances in worker-based deployments. """ - return self.instance_id + return self._instance_id + + def get_instance_name(self) -> str: + """A unique name for this synapse process. + + Used to identify the process over replication and in config. Does not + change over restarts. + """ + return self._instance_name def setup(self): logger.info("Setting up.") diff --git a/synapse/server.pyi b/synapse/server.pyi index fc5886f762..18043a2593 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -122,6 +122,8 @@ class HomeServer(object): pass def get_instance_id(self) -> str: pass + def get_instance_name(self) -> str: + pass def get_event_builder_factory(self) -> EventBuilderFactory: pass def get_storage(self) -> synapse.storage.Storage: diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 395c7d0306..1615dfab5e 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -57,6 +57,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): # We now do some gut wrenching so that we have a client that is based # off of the slave store rather than the main store. self.replication_handler = ReplicationCommandHandler(self.hs) + self.replication_handler._instance_name = "worker" self.replication_handler._replication_data_handler = ReplicationDataHandler( self.slaved_store ) diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py index 3cbcb513cc..7ddfd0a733 100644 --- a/tests/replication/tcp/test_commands.py +++ b/tests/replication/tcp/test_commands.py @@ -28,15 +28,17 @@ class ParseCommandTestCase(TestCase): self.assertIsInstance(cmd, ReplicateCommand) def test_parse_rdata(self): - line = 'RDATA events 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' + line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' cmd = parse_command_from_line(line) self.assertIsInstance(cmd, RdataCommand) self.assertEqual(cmd.stream_name, "events") + self.assertEqual(cmd.instance_name, "master") self.assertEqual(cmd.token, 6287863) def test_parse_rdata_batch(self): - line = 'RDATA presence batch ["@foo:example.com", "online"]' + line = 'RDATA presence master batch ["@foo:example.com", "online"]' cmd = parse_command_from_line(line) self.assertIsInstance(cmd, RdataCommand) self.assertEqual(cmd.stream_name, "presence") + self.assertEqual(cmd.instance_name, "master") self.assertIsNone(cmd.token) -- cgit 1.5.1 From d7983b63a6746d92225295f1e9d521f847cf8ba7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 7 May 2020 13:51:08 +0100 Subject: Support any process writing to cache invalidation stream. (#7436) --- changelog.d/7436.misc | 1 + docs/tcp_replication.md | 4 - scripts/synapse_port_db | 4 +- synapse/replication/slave/storage/_base.py | 50 +++---------- synapse/replication/slave/storage/account_data.py | 6 +- synapse/replication/slave/storage/deviceinbox.py | 6 +- synapse/replication/slave/storage/devices.py | 6 +- synapse/replication/slave/storage/events.py | 6 +- synapse/replication/slave/storage/groups.py | 6 +- synapse/replication/slave/storage/presence.py | 6 +- synapse/replication/slave/storage/push_rule.py | 6 +- synapse/replication/slave/storage/pushers.py | 6 +- synapse/replication/slave/storage/receipts.py | 6 +- synapse/replication/slave/storage/room.py | 4 +- synapse/replication/tcp/client.py | 6 +- synapse/replication/tcp/commands.py | 33 -------- synapse/replication/tcp/handler.py | 42 ++--------- synapse/replication/tcp/resource.py | 22 +++++- synapse/replication/tcp/streams/_base.py | 87 +++++++++++++++------- synapse/replication/tcp/streams/events.py | 4 +- synapse/replication/tcp/streams/federation.py | 12 ++- synapse/storage/_base.py | 3 + synapse/storage/data_stores/main/__init__.py | 15 +++- synapse/storage/data_stores/main/cache.py | 84 +++++++++++---------- .../schema/delta/58/05cache_instance.sql.postgres | 30 ++++++++ synapse/storage/prepare_database.py | 2 + 26 files changed, 226 insertions(+), 231 deletions(-) create mode 100644 changelog.d/7436.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres (limited to 'synapse/replication/tcp/commands.py') diff --git a/changelog.d/7436.misc b/changelog.d/7436.misc new file mode 100644 index 0000000000..f7c4514950 --- /dev/null +++ b/changelog.d/7436.misc @@ -0,0 +1 @@ +Support any process writing to cache invalidation stream. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index ab2fffbfe4..db318baa9d 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -219,10 +219,6 @@ Asks the server for the current position of all streams. Inform the server a pusher should be removed -#### INVALIDATE_CACHE (C) - - Inform the server a cache should be invalidated - ### REMOTE_SERVER_UP (S, C) Inform other processes that a remote server may have come back online. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index e8b698f3ff..acd9ac4b75 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -122,7 +122,7 @@ APPEND_ONLY_TABLES = [ "presence_stream", "push_rules_stream", "ex_outlier_stream", - "cache_invalidation_stream", + "cache_invalidation_stream_by_instance", "public_room_list_stream", "state_group_edges", "stream_ordering_to_exterm", @@ -188,7 +188,7 @@ class MockHomeserver: self.clock = Clock(reactor) self.config = config self.hostname = config.server_name - self.version_string = "Synapse/"+get_version_string(synapse) + self.version_string = "Synapse/" + get_version_string(synapse) def get_clock(self): return self.clock diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 5d7c8871a4..2904bd0235 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,14 +18,10 @@ from typing import Optional import six -from synapse.storage.data_stores.main.cache import ( - CURRENT_STATE_CACHE_NAME, - CacheInvalidationWorkerStore, -) +from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine - -from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage.util.id_generators import MultiWriterIdGenerator logger = logging.getLogger(__name__) @@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: Database, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = SlavedIdTracker( - db_conn, "cache_invalidation_stream", "stream_id" - ) # type: Optional[SlavedIdTracker] + self._cache_id_gen = MultiWriterIdGenerator( + db_conn, + database, + instance_name=hs.get_instance_name(), + table="cache_invalidation_stream_by_instance", + instance_column="instance_name", + id_column="stream_id", + sequence_name="cache_invalidation_stream_seq", + ) # type: Optional[MultiWriterIdGenerator] else: self._cache_id_gen = None self.hs = hs - - def get_cache_stream_token(self): - if self._cache_id_gen: - return self._cache_id_gen.get_current_token() - else: - return 0 - - def process_replication_rows(self, stream_name, token, rows): - if stream_name == "caches": - if self._cache_id_gen: - self._cache_id_gen.advance(token) - for row in rows: - if row.cache_func == CURRENT_STATE_CACHE_NAME: - if row.keys is None: - raise Exception( - "Can't send an 'invalidate all' for current state cache" - ) - - room_id = row.keys[0] - members_changed = set(row.keys[1:]) - self._invalidate_state_caches(room_id, members_changed) - else: - self._attempt_to_invalidate_cache(row.cache_func, row.keys) - - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - txn.call_after(cache_func.invalidate, keys) - txn.call_after(self._send_invalidation_poke, cache_func, keys) - - def _send_invalidation_poke(self, cache_func, keys): - self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys) diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 65e54b1c71..2a4f5c7cfd 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "tag_account_data": self._account_data_id_gen.advance(token) for row in rows: @@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved (row.user_id, row.room_id, row.data_type) ) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedAccountDataStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index c923751e50..6e7fd259d4 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "to_device": self._device_inbox_id_gen.advance(token) for row in rows: @@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): self._device_federation_outbox_stream_cache.entity_has_changed( row.entity, token ) - return super(SlavedDeviceInboxStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 58fb0eaae3..9d8067342f 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) self._invalidate_caches_for_devices(token, rows) @@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self._device_list_id_gen.advance(token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedDeviceStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) def _invalidate_caches_for_devices(self, token, rows): for row in rows: diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 15011259df..b313720a4b 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -93,7 +93,7 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "events": self._stream_id_gen.advance(token) for row in rows: @@ -111,9 +111,7 @@ class SlavedEventStore( row.relates_to, backfilled=True, ) - return super(SlavedEventStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) def _process_event_stream_row(self, token, row): data = row.data diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 01bcf0e882..1851e7d525 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "groups": self._group_updates_id_gen.advance(token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedGroupServerStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index fae3125072..bd79ba99be 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "presence": self._presence_id_gen.advance(token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) - return super(SlavedPresenceStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 6138796da4..5d5816d7eb 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -37,13 +37,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "push_rules": self._push_rules_stream_id_gen.advance(token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedPushRuleStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 67be337945..cb78b49acb 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def get_pushers_stream_token(self): return self._pushers_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "pushers": self._pushers_id_gen.advance(token) - return super(SlavedPusherStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 993432edcb..be716cc558 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) self.get_receipts_for_room.invalidate((room_id, receipt_type)) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "receipts": self._receipts_id_gen.advance(token) for row in rows: @@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): ) self._receipts_stream_cache.entity_has_changed(row.room_id, token) - return super(SlavedReceiptsStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 10dda8708f..8873bf37e5 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication_rows(stream_name, token, rows) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 3bbf3c3569..20cb8a654f 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -100,10 +100,10 @@ class ReplicationDataHandler: token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - self.store.process_replication_rows(stream_name, token, rows) + self.store.process_replication_rows(stream_name, instance_name, token, rows) - async def on_position(self, stream_name: str, token: int): - self.store.process_replication_rows(stream_name, token, []) + async def on_position(self, stream_name: str, instance_name: str, token: int): + self.store.process_replication_rows(stream_name, instance_name, token, []) def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f58e384d17..c04f622816 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -341,37 +341,6 @@ class RemovePusherCommand(Command): return " ".join((self.app_id, self.push_key, self.user_id)) -class InvalidateCacheCommand(Command): - """Sent by the client to invalidate an upstream cache. - - THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE - NOT DISASTROUS IF WE DROP ON THE FLOOR. - - Mainly used to invalidate destination retry timing caches. - - Format:: - - INVALIDATE_CACHE - - Where is a json list. - """ - - NAME = "INVALIDATE_CACHE" - - def __init__(self, cache_func, keys): - self.cache_func = cache_func - self.keys = keys - - @classmethod - def from_line(cls, line): - cache_func, keys_json = line.split(" ", 1) - - return cls(cache_func, json.loads(keys_json)) - - def to_line(self): - return " ".join((self.cache_func, _json_encoder.encode(self.keys))) - - class UserIpCommand(Command): """Sent periodically when a worker sees activity from a client. @@ -439,7 +408,6 @@ _COMMANDS = ( UserSyncCommand, FederationAckCommand, RemovePusherCommand, - InvalidateCacheCommand, UserIpCommand, RemoteServerUpCommand, ClearUserSyncsCommand, @@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = ( ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, - InvalidateCacheCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, RemoteServerUpCommand.NAME, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b14a3d9fca..7c5d6c76e7 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -15,18 +15,7 @@ # limitations under the License. import logging -from typing import ( - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Set, - Tuple, - TypeVar, -) +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar from prometheus_client import Counter @@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, FederationAckCommand, - InvalidateCacheCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -171,7 +159,7 @@ class ReplicationCommandHandler: return for stream_name, stream in self._streams.items(): - current_token = stream.current_token() + current_token = stream.current_token(self._instance_name) self.send_command( PositionCommand(stream_name, self._instance_name, current_token) ) @@ -210,18 +198,6 @@ class ReplicationCommandHandler: self._notifier.on_new_replication_data() - async def on_INVALIDATE_CACHE( - self, conn: AbstractConnection, cmd: InvalidateCacheCommand - ): - invalidate_cache_counter.inc() - - if self._is_master: - # We invalidate the cache locally, but then also stream that to other - # workers. - await self._store.invalidate_cache_and_stream( - cmd.cache_func, tuple(cmd.keys) - ) - async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): user_ip_cache_counter.inc() @@ -295,7 +271,7 @@ class ReplicationCommandHandler: rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - logger.debug("Received rdata %s -> %s", stream_name, token) + logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token) await self._replication_data_handler.on_rdata( stream_name, instance_name, token, rows ) @@ -326,7 +302,7 @@ class ReplicationCommandHandler: self._pending_batches.pop(stream_name, []) # Find where we previously streamed up to. - current_token = stream.current_token() + current_token = stream.current_token(cmd.instance_name) # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates @@ -363,7 +339,9 @@ class ReplicationCommandHandler: logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position(stream_name, cmd.token) + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) self._streams_by_connection.setdefault(conn, set()).add(stream_name) @@ -491,12 +469,6 @@ class ReplicationCommandHandler: cmd = RemovePusherCommand(app_id, push_key, user_id) self.send_command(cmd) - def send_invalidate_cache(self, cache_func: Callable, keys: tuple): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - def send_user_ip( self, user_id: str, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index b690abedad..002171ce7c 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol -from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream +from synapse.replication.tcp.streams import ( + STREAMS_MAP, + CachesStream, + FederationStream, + Stream, +) from synapse.util.metrics import Measure stream_updates_counter = Counter( @@ -71,11 +76,16 @@ class ReplicationStreamer(object): self.store = hs.get_datastore() self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self._instance_name = hs.get_instance_name() self._replication_torture_level = hs.config.replication_torture_level # Work out list of streams that this instance is the source of. self.streams = [] # type: List[Stream] + + # All workers can write to the cache invalidation stream. + self.streams.append(CachesStream(hs)) + if hs.config.worker_app is None: for stream in STREAMS_MAP.values(): if stream == FederationStream and hs.config.send_federation: @@ -83,6 +93,10 @@ class ReplicationStreamer(object): # has been disabled on the master. continue + if stream == CachesStream: + # We've already added it above. + continue + self.streams.append(stream(hs)) self.streams_by_name = {stream.NAME: stream for stream in self.streams} @@ -145,7 +159,9 @@ class ReplicationStreamer(object): random.shuffle(all_streams) for stream in all_streams: - if stream.last_token == stream.current_token(): + if stream.last_token == stream.current_token( + self._instance_name + ): continue if self._replication_torture_level: @@ -157,7 +173,7 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, - stream.current_token(), + stream.current_token(self._instance_name), ) try: updates, current_token, limited = await stream.get_updates() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 084604e8b0..b48a6a3e91 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -95,20 +95,25 @@ class Stream(object): def __init__( self, local_instance_name: str, - current_token_function: Callable[[], Token], + current_token_function: Callable[[str], Token], update_function: UpdateFunction, ): """Instantiate a Stream - current_token_function and update_function are callbacks which should be - implemented by subclasses. + `current_token_function` and `update_function` are callbacks which + should be implemented by subclasses. - current_token_function is called to get the current token of the underlying - stream. It is only meaningful on the process that is the source of the - replication stream (ie, usually the master). + `current_token_function` takes an instance name, which is a writer to + the stream, and returns the position in the stream of the writer (as + viewed from the current process). On the writer process this is where + the writer has successfully written up to, whereas on other processes + this is the position which we have received updates up to over + replication. (Note that most streams have a single writer and so their + implementations ignore the instance name passed in). - update_function is called to get updates for this stream between a pair of - stream tokens. See the UpdateFunction type definition for more info. + `update_function` is called to get updates for this stream between a + pair of stream tokens. See the `UpdateFunction` type definition for more + info. Args: local_instance_name: The instance name of the current process @@ -120,13 +125,13 @@ class Stream(object): self.update_function = update_function # The token from which we last asked for updates - self.last_token = self.current_token() + self.last_token = self.current_token(self.local_instance_name) def discard_updates_and_advance(self): """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ - self.last_token = self.current_token() + self.last_token = self.current_token(self.local_instance_name) async def get_updates(self) -> StreamUpdateResult: """Gets all updates since the last time this function was called (or @@ -138,7 +143,7 @@ class Stream(object): position in stream, and `limited` is whether there are more updates to fetch. """ - current_token = self.current_token() + current_token = self.current_token(self.local_instance_name) updates, current_token, limited = await self.get_updates_since( self.local_instance_name, self.last_token, current_token ) @@ -170,6 +175,16 @@ class Stream(object): return updates, upto_token, limited +def current_token_without_instance( + current_token: Callable[[], int] +) -> Callable[[str], int]: + """Takes a current token callback function for a single writer stream + that doesn't take an instance name parameter and wraps it in a function that + does accept an instance name parameter but ignores it. + """ + return lambda instance_name: current_token() + + def db_query_to_update_function( query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: @@ -235,7 +250,7 @@ class BackfillStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_current_backfill_token, + current_token_without_instance(store.get_current_backfill_token), db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -271,7 +286,9 @@ class PresenceStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), store.get_current_presence_token, update_function + hs.get_instance_name(), + current_token_without_instance(store.get_current_presence_token), + update_function, ) @@ -296,7 +313,9 @@ class TypingStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), typing_handler.get_current_token, update_function + hs.get_instance_name(), + current_token_without_instance(typing_handler.get_current_token), + update_function, ) @@ -319,7 +338,7 @@ class ReceiptsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_max_receipt_stream_id, + current_token_without_instance(store.get_max_receipt_stream_id), db_query_to_update_function(store.get_all_updated_receipts), ) @@ -339,7 +358,7 @@ class PushRulesStream(Stream): hs.get_instance_name(), self._current_token, self._update_function ) - def _current_token(self) -> int: + def _current_token(self, instance_name: str) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token @@ -373,7 +392,7 @@ class PushersStream(Stream): super().__init__( hs.get_instance_name(), - store.get_pushers_stream_token, + current_token_without_instance(store.get_pushers_stream_token), db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -402,12 +421,26 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs): - store = hs.get_datastore() + self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_cache_stream_token, - db_query_to_update_function(store.get_all_updated_caches), + self.store.get_cache_stream_token, + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, upto_token: int, limit: int + ): + rows = await self.store.get_all_updated_caches( + instance_name, from_token, upto_token, limit ) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited class PublicRoomsStream(Stream): @@ -431,7 +464,7 @@ class PublicRoomsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_current_public_room_stream_id, + current_token_without_instance(store.get_current_public_room_stream_id), db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -452,7 +485,7 @@ class DeviceListsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_device_stream_token, + current_token_without_instance(store.get_device_stream_token), db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -470,7 +503,7 @@ class ToDeviceStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_to_device_stream_token, + current_token_without_instance(store.get_to_device_stream_token), db_query_to_update_function(store.get_all_new_device_messages), ) @@ -490,7 +523,7 @@ class TagAccountDataStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_max_account_data_stream_id, + current_token_without_instance(store.get_max_account_data_stream_id), db_query_to_update_function(store.get_all_updated_tags), ) @@ -510,7 +543,7 @@ class AccountDataStream(Stream): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self.store.get_max_account_data_stream_id, + current_token_without_instance(self.store.get_max_account_data_stream_id), db_query_to_update_function(self._update_function), ) @@ -541,7 +574,7 @@ class GroupServerStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_group_stream_token, + current_token_without_instance(store.get_group_stream_token), db_query_to_update_function(store.get_all_groups_changes), ) @@ -559,7 +592,7 @@ class UserSignatureStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_device_stream_token, + current_token_without_instance(store.get_device_stream_token), db_query_to_update_function( store.get_all_user_signature_changes_for_remotes ), diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 890e75d827..f370390331 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -20,7 +20,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token +from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance """Handling of the 'events' replication stream @@ -119,7 +119,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self._store.get_current_events_token, + current_token_without_instance(self._store.get_current_events_token), self._update_function, ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index b0505b8a2c..9bcd13b009 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,11 @@ # limitations under the License. from collections import namedtuple -from synapse.replication.tcp.streams._base import Stream, make_http_update_function +from synapse.replication.tcp.streams._base import ( + Stream, + current_token_without_instance, + make_http_update_function, +) class FederationStream(Stream): @@ -41,7 +45,9 @@ class FederationStream(Stream): # will be a real FederationSender, which has stubs for current_token and # get_replication_rows.) federation_sender = hs.get_federation_sender() - current_token = federation_sender.get_current_token + current_token = current_token_without_instance( + federation_sender.get_current_token + ) update_function = federation_sender.get_replication_rows elif hs.should_send_federation(): @@ -58,7 +64,7 @@ class FederationStream(Stream): super().__init__(hs.get_instance_name(), current_token, update_function) @staticmethod - def _stub_current_token(): + def _stub_current_token(instance_name: str) -> int: # dummy current-token method for use on workers return 0 diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 13de5f1f62..59073c0a42 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -47,6 +47,9 @@ class SQLBaseStore(metaclass=ABCMeta): self.db = database self.rand = random.SystemRandom() + def process_replication_rows(self, stream_name, instance_name, token, rows): + pass + def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does not stream invalidations down replication. diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index ceba10882c..cd2a1f0461 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -26,13 +26,14 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( ChainedIdGenerator, IdGenerator, + MultiWriterIdGenerator, StreamIdGenerator, ) from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore -from .cache import CacheInvalidationStore +from .cache import CacheInvalidationWorkerStore from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore @@ -112,8 +113,8 @@ class DataStore( MonthlyActiveUsersStore, StatsStore, RelationsStore, - CacheInvalidationStore, UIAuthStore, + CacheInvalidationWorkerStore, ): def __init__(self, database: Database, db_conn, hs): self.hs = hs @@ -170,8 +171,14 @@ class DataStore( ) if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = StreamIdGenerator( - db_conn, "cache_invalidation_stream", "stream_id" + self._cache_id_gen = MultiWriterIdGenerator( + db_conn, + database, + instance_name="master", + table="cache_invalidation_stream_by_instance", + instance_column="instance_name", + id_column="stream_id", + sequence_name="cache_invalidation_stream_seq", ) else: self._cache_id_gen = None diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index 4dc5da3fe8..342a87a46b 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -16,11 +16,10 @@ import itertools import logging -from typing import Any, Iterable, Optional, Tuple - -from twisted.internet import defer +from typing import Any, Iterable, Optional from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter @@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class CacheInvalidationWorkerStore(SQLBaseStore): - def get_all_updated_caches(self, last_id, current_id, limit): + def __init__(self, database: Database, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._instance_name = hs.get_instance_name() + + async def get_all_updated_caches( + self, instance_name: str, last_id: int, current_id: int, limit: int + ): + """Fetches cache invalidation rows between the two given IDs written + by the given instance. Returns at most `limit` rows. + """ + if last_id == current_id: - return defer.succeed([]) + return [] def get_all_updated_caches_txn(txn): # We purposefully don't bound by the current token, as we want to # send across cache invalidations as quickly as possible. Cache # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) + sql = """ + SELECT stream_id, cache_func, keys, invalidation_ts + FROM cache_invalidation_stream_by_instance + WHERE stream_id > ? AND instance_name = ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, instance_name, limit)) return txn.fetchall() - return self.db.runInteraction( + return await self.db.runInteraction( "get_all_updated_caches", get_all_updated_caches_txn ) + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == "caches": + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) -class CacheInvalidationStore(CacheInvalidationWorkerStore): - async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): - """Invalidates the cache and adds it to the cache stream so slaves - will know to invalidate their caches. + for row in rows: + if row.cache_func == CURRENT_STATE_CACHE_NAME: + if row.keys is None: + raise Exception( + "Can't send an 'invalidate all' for current state cache" + ) - This should only be used to invalidate caches where slaves won't - otherwise know from other replication streams that the cache should - be invalidated. - """ - cache_func = getattr(self, cache_name, None) - if not cache_func: - return - - cache_func.invalidate(keys) - await self.runInteraction( - "invalidate_cache_and_stream", - self._send_invalidation_to_replication, - cache_func.__name__, - keys, - ) + room_id = row.keys[0] + members_changed = set(row.keys[1:]) + self._invalidate_state_caches(room_id, members_changed) + else: + self._attempt_to_invalidate_cache(row.cache_func, row.keys) + + super().process_replication_rows(stream_name, instance_name, token, rows) def _invalidate_cache_and_stream(self, txn, cache_func, keys): """Invalidates the cache and adds it to the cache stream so slaves @@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore): # the transaction. However, we want to only get an ID when we want # to use it, here, so we need to call __enter__ manually, and have # __exit__ called after the transaction finishes. - ctx = self._cache_id_gen.get_next() - stream_id = ctx.__enter__() - txn.call_on_exception(ctx.__exit__, None, None, None) - txn.call_after(ctx.__exit__, None, None, None) + stream_id = self._cache_id_gen.get_next_txn(txn) txn.call_after(self.hs.get_notifier().on_new_replication_data) if keys is not None: @@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore): self.db.simple_insert_txn( txn, - table="cache_invalidation_stream", + table="cache_invalidation_stream_by_instance", values={ "stream_id": stream_id, + "instance_name": self._instance_name, "cache_func": cache_name, "keys": keys, "invalidation_ts": self.clock.time_msec(), }, ) - def get_cache_stream_token(self): + def get_cache_stream_token(self, instance_name): if self._cache_id_gen: - return self._cache_id_gen.get_current_token() + return self._cache_id_gen.get_current_token(instance_name) else: return 0 diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres new file mode 100644 index 0000000000..aa46eb0e10 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres @@ -0,0 +1,30 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We keep the old table here to enable us to roll back. It doesn't matter +-- that we have dropped all the data here. +TRUNCATE cache_invalidation_stream; + +CREATE TABLE cache_invalidation_stream_by_instance ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + cache_func TEXT NOT NULL, + keys TEXT[], + invalidation_ts BIGINT +); + +CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id); + +CREATE SEQUENCE cache_invalidation_stream_seq; diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 1712932f31..640f242584 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -29,6 +29,8 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. +# XXX: If you're about to bump this to 59 (or higher) please create an update +# that drops the unused `cache_invalidation_stream` table, as per #7436! SCHEMA_VERSION = 58 dir_path = os.path.abspath(os.path.dirname(__file__)) -- cgit 1.5.1