From e8b68a4e4b439065536c281d8997af85880f6ee2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 Jan 2020 14:08:06 +0000 Subject: Fixup synapse.replication to pass mypy checks (#6667) --- synapse/replication/tcp/streams/_base.py | 57 ++++++++++++++------------- synapse/replication/tcp/streams/events.py | 16 +++++--- synapse/replication/tcp/streams/federation.py | 4 +- 3 files changed, 42 insertions(+), 35 deletions(-) (limited to 'synapse/replication/tcp/streams') diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 8512923eae..4ab0334fc1 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,10 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging from collections import namedtuple +from typing import Any from twisted.internet import defer @@ -104,8 +104,9 @@ class Stream(object): time it was called up until the point `advance_current_token` was called. """ - NAME = None # The name of the stream - ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. + NAME = None # type: str # The name of the stream + # The type of the row. Used by the default impl of parse_row. + ROW_TYPE = None # type: Any _LIMITED = True # Whether the update function takes a limit @classmethod @@ -231,8 +232,8 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_current_backfill_token - self.update_function = store.get_all_new_backfill_event_rows + self.current_token = store.get_current_backfill_token # type: ignore + self.update_function = store.get_all_new_backfill_event_rows # type: ignore super(BackfillStream, self).__init__(hs) @@ -246,8 +247,8 @@ class PresenceStream(Stream): store = hs.get_datastore() presence_handler = hs.get_presence_handler() - self.current_token = store.get_current_presence_token - self.update_function = presence_handler.get_all_presence_updates + self.current_token = store.get_current_presence_token # type: ignore + self.update_function = presence_handler.get_all_presence_updates # type: ignore super(PresenceStream, self).__init__(hs) @@ -260,8 +261,8 @@ class TypingStream(Stream): def __init__(self, hs): typing_handler = hs.get_typing_handler() - self.current_token = typing_handler.get_current_token - self.update_function = typing_handler.get_all_typing_updates + self.current_token = typing_handler.get_current_token # type: ignore + self.update_function = typing_handler.get_all_typing_updates # type: ignore super(TypingStream, self).__init__(hs) @@ -273,8 +274,8 @@ class ReceiptsStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_max_receipt_stream_id - self.update_function = store.get_all_updated_receipts + self.current_token = store.get_max_receipt_stream_id # type: ignore + self.update_function = store.get_all_updated_receipts # type: ignore super(ReceiptsStream, self).__init__(hs) @@ -310,8 +311,8 @@ class PushersStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_pushers_stream_token - self.update_function = store.get_all_updated_pushers_rows + self.current_token = store.get_pushers_stream_token # type: ignore + self.update_function = store.get_all_updated_pushers_rows # type: ignore super(PushersStream, self).__init__(hs) @@ -327,8 +328,8 @@ class CachesStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_cache_stream_token - self.update_function = store.get_all_updated_caches + self.current_token = store.get_cache_stream_token # type: ignore + self.update_function = store.get_all_updated_caches # type: ignore super(CachesStream, self).__init__(hs) @@ -343,8 +344,8 @@ class PublicRoomsStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_current_public_room_stream_id - self.update_function = store.get_all_new_public_rooms + self.current_token = store.get_current_public_room_stream_id # type: ignore + self.update_function = store.get_all_new_public_rooms # type: ignore super(PublicRoomsStream, self).__init__(hs) @@ -360,8 +361,8 @@ class DeviceListsStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_device_stream_token - self.update_function = store.get_all_device_list_changes_for_remotes + self.current_token = store.get_device_stream_token # type: ignore + self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore super(DeviceListsStream, self).__init__(hs) @@ -376,8 +377,8 @@ class ToDeviceStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_to_device_stream_token - self.update_function = store.get_all_new_device_messages + self.current_token = store.get_to_device_stream_token # type: ignore + self.update_function = store.get_all_new_device_messages # type: ignore super(ToDeviceStream, self).__init__(hs) @@ -392,8 +393,8 @@ class TagAccountDataStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_max_account_data_stream_id - self.update_function = store.get_all_updated_tags + self.current_token = store.get_max_account_data_stream_id # type: ignore + self.update_function = store.get_all_updated_tags # type: ignore super(TagAccountDataStream, self).__init__(hs) @@ -408,7 +409,7 @@ class AccountDataStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() - self.current_token = self.store.get_max_account_data_stream_id + self.current_token = self.store.get_max_account_data_stream_id # type: ignore super(AccountDataStream, self).__init__(hs) @@ -434,8 +435,8 @@ class GroupServerStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_group_stream_token - self.update_function = store.get_all_groups_changes + self.current_token = store.get_group_stream_token # type: ignore + self.update_function = store.get_all_groups_changes # type: ignore super(GroupServerStream, self).__init__(hs) @@ -451,7 +452,7 @@ class UserSignatureStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_device_stream_token - self.update_function = store.get_all_user_signature_changes_for_remotes + self.current_token = store.get_device_stream_token # type: ignore + self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index d97669c886..0843e5aa90 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -13,7 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import heapq +from typing import Tuple, Type import attr @@ -63,7 +65,8 @@ class BaseEventsStreamRow(object): Specifies how to identify, serialize and deserialize the different types. """ - TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + # Unique string that ids the type. Must be overriden in sub classes. + TypeId = None # type: str @classmethod def from_data(cls, data): @@ -99,9 +102,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow): event_id = attr.ib() # str, optional -TypeToRow = { - Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow) -} +_EventRows = ( + EventsStreamEventRow, + EventsStreamCurrentStateRow, +) # type: Tuple[Type[BaseEventsStreamRow], ...] + +TypeToRow = {Row.TypeId: Row for Row in _EventRows} class EventsStream(Stream): @@ -112,7 +118,7 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() - self.current_token = self._store.get_current_events_token + self.current_token = self._store.get_current_events_token # type: ignore super(EventsStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index dc2484109d..615f3dc9ac 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -37,7 +37,7 @@ class FederationStream(Stream): def __init__(self, hs): federation_sender = hs.get_federation_sender() - self.current_token = federation_sender.get_current_token - self.update_function = federation_sender.get_replication_rows + self.current_token = federation_sender.get_current_token # type: ignore + self.update_function = federation_sender.get_replication_rows # type: ignore super(FederationStream, self).__init__(hs) -- cgit 1.5.1 From 48c3a96886de64f3141ad68b8163cd2fc0c197ff Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 Jan 2020 09:16:12 +0000 Subject: Port synapse.replication.tcp to async/await (#6666) * Port synapse.replication.tcp to async/await * Newsfile * Correctly document type of on_ functions as async * Don't be overenthusiastic with the asyncing.... --- changelog.d/6666.misc | 1 + synapse/app/admin_cmd.py | 3 +- synapse/app/appservice.py | 5 +-- synapse/app/federation_sender.py | 5 +-- synapse/app/pusher.py | 5 +-- synapse/app/synchrotron.py | 5 +-- synapse/app/user_dir.py | 5 +-- synapse/federation/send_queue.py | 4 +- synapse/handlers/typing.py | 2 +- synapse/replication/tcp/client.py | 11 ++--- synapse/replication/tcp/protocol.py | 72 ++++++++++++++----------------- synapse/replication/tcp/resource.py | 31 ++++++------- synapse/replication/tcp/streams/_base.py | 25 +++++------ synapse/replication/tcp/streams/events.py | 9 ++-- tests/replication/tcp/streams/_base.py | 2 +- 15 files changed, 80 insertions(+), 105 deletions(-) create mode 100644 changelog.d/6666.misc (limited to 'synapse/replication/tcp/streams') diff --git a/changelog.d/6666.misc b/changelog.d/6666.misc new file mode 100644 index 0000000000..e79c23d2d2 --- /dev/null +++ b/changelog.d/6666.misc @@ -0,0 +1 @@ +Port `synapse.replication.tcp` to async/await. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 8e36bc57d3..1c7c6ec0c8 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -84,8 +84,7 @@ class AdminCmdServer(HomeServer): class AdminCmdReplicationHandler(ReplicationClientHandler): - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): pass def get_streams_to_replicate(self): diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index e82e0f11e3..2217d4a4fb 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -115,9 +115,8 @@ class ASReplicationHandler(ReplicationClientHandler): super(ASReplicationHandler, self).__init__(hs.get_datastore()) self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) if stream_name == "events": max_stream_id = self.store.get_room_max_stream_ordering() diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 83c436229c..a57cf991ac 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -145,9 +145,8 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) self.send_handler = FederationSenderHandler(hs, self) - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(FederationSenderReplicationHandler, self).on_rdata( + async def on_rdata(self, stream_name, token, rows): + await super(FederationSenderReplicationHandler, self).on_rdata( stream_name, token, rows ) self.send_handler.process_replication_rows(stream_name, token, rows) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 09e639040a..e46b6ac598 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -141,9 +141,8 @@ class PusherReplicationHandler(ReplicationClientHandler): self.pusher_pool = hs.get_pusherpool() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.poke_pushers, stream_name, token, rows) @defer.inlineCallbacks diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 03031ee34d..3218da07bd 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -358,9 +358,8 @@ class SyncReplicationHandler(ReplicationClientHandler): self.presence_handler = hs.get_presence_handler() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.process_and_notify, stream_name, token, rows) def get_streams_to_replicate(self): diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 1257098f92..ba536d6f04 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -172,9 +172,8 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) self.user_directory = hs.get_user_directory_handler() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(UserDirectoryReplicationHandler, self).on_rdata( + async def on_rdata(self, stream_name, token, rows): + await super(UserDirectoryReplicationHandler, self).on_rdata( stream_name, token, rows ) if stream_name == EventsStream.NAME: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index ced4925a98..174f6e42be 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -259,7 +259,9 @@ class FederationRemoteSendQueue(object): def federation_ack(self, token): self._clear_queue_before_pos(token) - def get_replication_rows(self, from_token, to_token, limit, federation_ack=None): + async def get_replication_rows( + self, from_token, to_token, limit, federation_ack=None + ): """Get rows to be sent over federation between the two tokens Args: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index b635c339ed..d5ca9cb07b 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -257,7 +257,7 @@ class TypingHandler(object): "typing_key", self._latest_room_serial, rooms=[member.room_id] ) - def get_all_typing_updates(self, last_id, current_id): + async def get_all_typing_updates(self, last_id, current_id): if last_id == current_id: return [] diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index aa7fd90e26..52a0aefe68 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -110,7 +110,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self.factory) - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -121,20 +121,17 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): token (int): stream token for this batch of rows rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. - - Returns: - Deferred|None """ logger.debug("Received rdata %s -> %s", stream_name, token) - return self.store.process_replication_rows(stream_name, token, rows) + self.store.process_replication_rows(stream_name, token, rows) - def on_position(self, stream_name, token): + async def on_position(self, stream_name, token): """Called when we get new position data. By default this just pokes the slave store. Can be overriden in subclasses to handle more. """ - return self.store.process_replication_rows(stream_name, token, []) + self.store.process_replication_rows(stream_name, token, []) def on_sync(self, data): """When we received a SYNC we wake up any deferreds that were waiting diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index db0353c996..5f4bdf84d2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -81,12 +81,11 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) +from synapse.replication.tcp.streams import STREAMS_MAP from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string -from .streams import STREAMS_MAP - connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -241,19 +240,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd ) - def handle_command(self, cmd): + async def handle_command(self, cmd: Command): """Handle a command we have received over the replication stream. - By default delegates to on_ + By default delegates to on_, which should return an awaitable. Args: - cmd (synapse.replication.tcp.commands.Command): received command - - Returns: - Deferred + cmd: received command """ handler = getattr(self, "on_%s" % (cmd.NAME,)) - return handler(cmd) + await handler(cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -326,10 +322,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - def on_PING(self, line): + async def on_PING(self, line): self.received_ping = True - def on_ERROR(self, cmd): + async def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -429,16 +425,16 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): BaseReplicationStreamProtocol.connectionMade(self) self.streamer.new_connection(self) - def on_NAME(self, cmd): + async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - def on_USER_SYNC(self, cmd): - return self.streamer.on_user_sync( + async def on_USER_SYNC(self, cmd): + await self.streamer.on_user_sync( self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) - def on_REPLICATE(self, cmd): + async def on_REPLICATE(self, cmd): stream_name = cmd.stream_name token = cmd.token @@ -449,23 +445,23 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): for stream in iterkeys(self.streamer.streams_by_name) ] - return make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) else: - return self.subscribe_to_stream(stream_name, token) + await self.subscribe_to_stream(stream_name, token) - def on_FEDERATION_ACK(self, cmd): - return self.streamer.federation_ack(cmd.token) + async def on_FEDERATION_ACK(self, cmd): + self.streamer.federation_ack(cmd.token) - def on_REMOVE_PUSHER(self, cmd): - return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) + async def on_REMOVE_PUSHER(self, cmd): + await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - def on_INVALIDATE_CACHE(self, cmd): - return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + async def on_INVALIDATE_CACHE(self, cmd): + self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - def on_USER_IP(self, cmd): - return self.streamer.on_user_ip( + async def on_USER_IP(self, cmd): + self.streamer.on_user_ip( cmd.user_id, cmd.access_token, cmd.ip, @@ -474,8 +470,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): cmd.last_seen, ) - @defer.inlineCallbacks - def subscribe_to_stream(self, stream_name, token): + async def subscribe_to_stream(self, stream_name, token): """Subscribe the remote to a stream. This invloves checking if they've missed anything and sending those @@ -487,7 +482,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): try: # Get missing updates - updates, current_token = yield self.streamer.get_stream_updates( + updates, current_token = await self.streamer.get_stream_updates( stream_name, token ) @@ -572,7 +567,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): """Called to handle a batch of replication data with a given stream token. Args: @@ -580,14 +575,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): token (int): stream token for this batch of rows rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. - - Returns: - Deferred|None """ raise NotImplementedError() @abc.abstractmethod - def on_position(self, stream_name, token): + async def on_position(self, stream_name, token): """Called when we get new position data.""" raise NotImplementedError() @@ -676,12 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): if not self.streams_connecting: self.handler.finished_connecting() - def on_SERVER(self, cmd): + async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - def on_RDATA(self, cmd): + async def on_RDATA(self, cmd): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -701,19 +693,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Check if this is the last of a batch of updates rows = self.pending_batches.pop(stream_name, []) rows.append(row) - return self.handler.on_rdata(stream_name, cmd.token, rows) + await self.handler.on_rdata(stream_name, cmd.token, rows) - def on_POSITION(self, cmd): + async def on_POSITION(self, cmd): # When we get a `POSITION` command it means we've finished getting # missing updates for the given stream, and are now up to date. self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - return self.handler.on_position(cmd.stream_name, cmd.token) + await self.handler.on_position(cmd.stream_name, cmd.token) - def on_SYNC(self, cmd): - return self.handler.on_sync(cmd.data) + async def on_SYNC(self, cmd): + self.handler.on_sync(cmd.data) def replicate(self, stream_name, token): """Send the subscription request to the server diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index cbfdaf5773..b1752e88cd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -23,7 +23,6 @@ from six import itervalues from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.protocol import Factory from synapse.metrics import LaterGauge @@ -155,8 +154,7 @@ class ReplicationStreamer(object): run_as_background_process("replication_notifier", self._run_notifier_loop) - @defer.inlineCallbacks - def _run_notifier_loop(self): + async def _run_notifier_loop(self): self.is_looping = True try: @@ -185,7 +183,7 @@ class ReplicationStreamer(object): continue if self._replication_torture_level: - yield self.clock.sleep( + await self.clock.sleep( self._replication_torture_level / 1000.0 ) @@ -196,7 +194,7 @@ class ReplicationStreamer(object): stream.upto_token, ) try: - updates, current_token = yield stream.get_updates() + updates, current_token = await stream.get_updates() except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise @@ -233,7 +231,7 @@ class ReplicationStreamer(object): self.is_looping = False @measure_func("repl.get_stream_updates") - def get_stream_updates(self, stream_name, token): + async def get_stream_updates(self, stream_name, token): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. """ @@ -241,7 +239,7 @@ class ReplicationStreamer(object): if not stream: raise Exception("unknown stream %s", stream_name) - return stream.get_updates_since(token) + return await stream.get_updates_since(token) @measure_func("repl.federation_ack") def federation_ack(self, token): @@ -252,22 +250,20 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") - @defer.inlineCallbacks - def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() - yield self.presence_handler.update_external_syncs_row( + await self.presence_handler.update_external_syncs_row( conn_id, user_id, is_syncing, last_sync_ms ) @measure_func("repl.on_remove_pusher") - @defer.inlineCallbacks - def on_remove_pusher(self, app_id, push_key, user_id): + async def on_remove_pusher(self, app_id, push_key, user_id): """A client has asked us to remove a pusher """ remove_pusher_counter.inc() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id=app_id, pushkey=push_key, user_id=user_id ) @@ -281,15 +277,16 @@ class ReplicationStreamer(object): getattr(self.store, cache_func).invalidate(tuple(keys)) @measure_func("repl.on_user_ip") - @defer.inlineCallbacks - def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): + async def on_user_ip( + self, user_id, access_token, ip, user_agent, device_id, last_seen + ): """The client saw a user request """ user_ip_cache_counter.inc() - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id, access_token, ip, user_agent, device_id, last_seen ) - yield self._server_notices_sender.on_user_ip(user_id) + await self._server_notices_sender.on_user_ip(user_id) def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4ab0334fc1..e03e77199b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -19,8 +19,6 @@ import logging from collections import namedtuple from typing import Any -from twisted.internet import defer - logger = logging.getLogger(__name__) @@ -144,8 +142,7 @@ class Stream(object): self.upto_token = self.current_token() self.last_token = self.upto_token - @defer.inlineCallbacks - def get_updates(self): + async def get_updates(self): """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before), until the `upto_token` @@ -156,13 +153,12 @@ class Stream(object): list of ``(token, row)`` entries. ``row`` will be json-serialised and sent over the replication steam. """ - updates, current_token = yield self.get_updates_since(self.last_token) + updates, current_token = await self.get_updates_since(self.last_token) self.last_token = current_token return updates, current_token - @defer.inlineCallbacks - def get_updates_since(self, from_token): + async def get_updates_since(self, from_token): """Like get_updates except allows specifying from when we should stream updates @@ -182,15 +178,16 @@ class Stream(object): if from_token == current_token: return [], current_token + logger.info("get_updates_since: %s", self.__class__) if self._LIMITED: - rows = yield self.update_function( + rows = await self.update_function( from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 ) # never turn more than MAX_EVENTS_BEHIND + 1 into updates. rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: - rows = yield self.update_function(from_token, current_token) + rows = await self.update_function(from_token, current_token) updates = [(row[0], row[1:]) for row in rows] @@ -295,9 +292,8 @@ class PushRulesStream(Stream): push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) + async def update_function(self, from_token, to_token, limit): + rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) return [(row[0], row[2]) for row in rows] @@ -413,9 +409,8 @@ class AccountDataStream(Stream): super(AccountDataStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - global_results, room_results = yield self.store.get_all_updated_account_data( + async def update_function(self, from_token, to_token, limit): + global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 0843e5aa90..b3afabb8cd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,8 +19,6 @@ from typing import Tuple, Type import attr -from twisted.internet import defer - from ._base import Stream @@ -122,16 +120,15 @@ class EventsStream(Stream): super(EventsStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, current_token, limit=None): - event_rows = yield self._store.get_all_new_forward_event_rows( + async def update_function(self, from_token, current_token, limit=None): + event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) event_updates = ( (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows ) - state_rows = yield self._store.get_all_updated_current_state_deltas( + state_rows = await self._store.get_all_updated_current_state_deltas( from_token, current_token, limit ) state_updates = ( diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 1d14e77255..e96ad4ca4e 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -73,6 +73,6 @@ class TestReplicationClientHandler(object): def finished_connecting(self): pass - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): for r in rows: self.received_rdata_rows.append((stream_name, token, r)) -- cgit 1.5.1 From 5d7a6ad2238981646b2ae7b4071d8715281d181a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 22 Jan 2020 10:37:00 +0000 Subject: Allow streaming cache invalidate all to workers. (#6749) --- changelog.d/6749.misc | 1 + docs/tcp_replication.md | 5 +++++ synapse/replication/slave/storage/_base.py | 7 ++++++- synapse/replication/tcp/streams/_base.py | 26 +++++++++++++++++++++----- synapse/storage/_base.py | 18 +++++++++++++----- synapse/storage/data_stores/main/cache.py | 27 +++++++++++++++++++++++---- 6 files changed, 69 insertions(+), 15 deletions(-) create mode 100644 changelog.d/6749.misc (limited to 'synapse/replication/tcp/streams') diff --git a/changelog.d/6749.misc b/changelog.d/6749.misc new file mode 100644 index 0000000000..9fa13cb1d4 --- /dev/null +++ b/changelog.d/6749.misc @@ -0,0 +1 @@ +Allow streaming cache 'invalidate all' to workers. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index a0b1d563ff..e3a4634b14 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -254,6 +254,11 @@ and they key to invalidate. For example: > RDATA caches 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251] +Alternatively, an entire cache can be invalidated by sending down a `null` +instead of the key. For example: + + > RDATA caches 550953772 ["get_user_by_id", null, 1550574873252] + However, there are times when a number of caches need to be invalidated at the same time with the same key. To reduce traffic we batch those invalidations into a single poke by defining a special cache name that diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 704282c800..f45cbd37a0 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -66,11 +66,16 @@ class BaseSlavedStore(SQLBaseStore): self._cache_id_gen.advance(token) for row in rows: if row.cache_func == CURRENT_STATE_CACHE_NAME: + if row.keys is None: + raise Exception( + "Can't send an 'invalidate all' for current state cache" + ) + room_id = row.keys[0] members_changed = set(row.keys[1:]) self._invalidate_state_caches(room_id, members_changed) else: - self._attempt_to_invalidate_cache(row.cache_func, tuple(row.keys)) + self._attempt_to_invalidate_cache(row.cache_func, row.keys) def _invalidate_cache_and_stream(self, txn, cache_func, keys): txn.call_after(cache_func.invalidate, keys) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index e03e77199b..a8d568b14a 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -17,7 +17,9 @@ import itertools import logging from collections import namedtuple -from typing import Any +from typing import Any, List, Optional + +import attr logger = logging.getLogger(__name__) @@ -65,10 +67,24 @@ PushersStreamRow = namedtuple( "PushersStreamRow", ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool ) -CachesStreamRow = namedtuple( - "CachesStreamRow", - ("cache_func", "keys", "invalidation_ts"), # str # list(str) # int -) + + +@attr.s +class CachesStreamRow: + """Stream to inform workers they should invalidate their cache. + + Attributes: + cache_func: Name of the cached function. + keys: The entry in the cache to invalidate. If None then will + invalidate all. + invalidation_ts: Timestamp of when the invalidation took place. + """ + + cache_func = attr.ib(type=str) + keys = attr.ib(type=Optional[List[Any]]) + invalidation_ts = attr.ib(type=int) + + PublicRoomsStreamRow = namedtuple( "PublicRoomsStreamRow", ( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 3bb9381663..da3b99f93d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,6 +17,7 @@ import logging import random from abc import ABCMeta +from typing import Any, Optional from six import PY2 from six.moves import builtins @@ -26,7 +27,7 @@ from canonicaljson import json from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import Database -from synapse.types import get_domain_from_id +from synapse.types import Collection, get_domain_from_id logger = logging.getLogger(__name__) @@ -63,17 +64,24 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,)) - def _attempt_to_invalidate_cache(self, cache_name, key): + def _attempt_to_invalidate_cache( + self, cache_name: str, key: Optional[Collection[Any]] + ): """Attempts to invalidate the cache of the given name, ignoring if the cache doesn't exist. Mainly used for invalidating caches on workers, where they may not have the cache. Args: - cache_name (str) - key (tuple) + cache_name + key: Entry to invalidate. If None then invalidates the entire + cache. """ + try: - getattr(self, cache_name).invalidate(key) + if key is None: + getattr(self, cache_name).invalidate_all() + else: + getattr(self, cache_name).invalidate(tuple(key)) except AttributeError: # We probably haven't pulled in the cache in this worker, # which is fine. diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index bf91512daf..afa2b41c98 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -16,6 +16,7 @@ import itertools import logging +from typing import Any, Iterable, Optional from twisted.internet import defer @@ -43,6 +44,14 @@ class CacheInvalidationStore(SQLBaseStore): txn.call_after(cache_func.invalidate, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + def _invalidate_all_cache_and_stream(self, txn, cache_func): + """Invalidates the entire cache and adds it to the cache stream so slaves + will know to invalidate their caches. + """ + + txn.call_after(cache_func.invalidate_all) + self._send_invalidation_to_replication(txn, cache_func.__name__, None) + def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): """Special case invalidation of caches based on current state. @@ -73,17 +82,24 @@ class CacheInvalidationStore(SQLBaseStore): txn, CURRENT_STATE_CACHE_NAME, [room_id] ) - def _send_invalidation_to_replication(self, txn, cache_name, keys): + def _send_invalidation_to_replication( + self, txn, cache_name: str, keys: Optional[Iterable[Any]] + ): """Notifies replication that given cache has been invalidated. Note that this does *not* invalidate the cache locally. Args: txn - cache_name (str) - keys (iterable[str]) + cache_name + keys: Entry to invalidate. If None will invalidate all. """ + if cache_name == CURRENT_STATE_CACHE_NAME and keys is None: + raise Exception( + "Can't stream invalidate all with magic current state cache" + ) + if isinstance(self.database_engine, PostgresEngine): # get_next() returns a context manager which is designed to wrap # the transaction. However, we want to only get an ID when we want @@ -95,13 +111,16 @@ class CacheInvalidationStore(SQLBaseStore): txn.call_after(ctx.__exit__, None, None, None) txn.call_after(self.hs.get_notifier().on_new_replication_data) + if keys is not None: + keys = list(keys) + self.db.simple_insert_txn( txn, table="cache_invalidation_stream", values={ "stream_id": stream_id, "cache_func": cache_name, - "keys": list(keys), + "keys": keys, "invalidation_ts": self.clock.time_msec(), }, ) -- cgit 1.5.1