From 48c3a96886de64f3141ad68b8163cd2fc0c197ff Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 Jan 2020 09:16:12 +0000 Subject: Port synapse.replication.tcp to async/await (#6666) * Port synapse.replication.tcp to async/await * Newsfile * Correctly document type of on_ functions as async * Don't be overenthusiastic with the asyncing.... --- synapse/replication/tcp/client.py | 11 ++--- synapse/replication/tcp/protocol.py | 72 ++++++++++++++----------------- synapse/replication/tcp/resource.py | 31 ++++++------- synapse/replication/tcp/streams/_base.py | 25 +++++------ synapse/replication/tcp/streams/events.py | 9 ++-- 5 files changed, 63 insertions(+), 85 deletions(-) (limited to 'synapse/replication/tcp') diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index aa7fd90e26..52a0aefe68 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -110,7 +110,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self.factory) - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -121,20 +121,17 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): token (int): stream token for this batch of rows rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. - - Returns: - Deferred|None """ logger.debug("Received rdata %s -> %s", stream_name, token) - return self.store.process_replication_rows(stream_name, token, rows) + self.store.process_replication_rows(stream_name, token, rows) - def on_position(self, stream_name, token): + async def on_position(self, stream_name, token): """Called when we get new position data. By default this just pokes the slave store. Can be overriden in subclasses to handle more. """ - return self.store.process_replication_rows(stream_name, token, []) + self.store.process_replication_rows(stream_name, token, []) def on_sync(self, data): """When we received a SYNC we wake up any deferreds that were waiting diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index db0353c996..5f4bdf84d2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -81,12 +81,11 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) +from synapse.replication.tcp.streams import STREAMS_MAP from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string -from .streams import STREAMS_MAP - connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -241,19 +240,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd ) - def handle_command(self, cmd): + async def handle_command(self, cmd: Command): """Handle a command we have received over the replication stream. - By default delegates to on_ + By default delegates to on_, which should return an awaitable. Args: - cmd (synapse.replication.tcp.commands.Command): received command - - Returns: - Deferred + cmd: received command """ handler = getattr(self, "on_%s" % (cmd.NAME,)) - return handler(cmd) + await handler(cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -326,10 +322,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - def on_PING(self, line): + async def on_PING(self, line): self.received_ping = True - def on_ERROR(self, cmd): + async def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -429,16 +425,16 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): BaseReplicationStreamProtocol.connectionMade(self) self.streamer.new_connection(self) - def on_NAME(self, cmd): + async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - def on_USER_SYNC(self, cmd): - return self.streamer.on_user_sync( + async def on_USER_SYNC(self, cmd): + await self.streamer.on_user_sync( self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) - def on_REPLICATE(self, cmd): + async def on_REPLICATE(self, cmd): stream_name = cmd.stream_name token = cmd.token @@ -449,23 +445,23 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): for stream in iterkeys(self.streamer.streams_by_name) ] - return make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) else: - return self.subscribe_to_stream(stream_name, token) + await self.subscribe_to_stream(stream_name, token) - def on_FEDERATION_ACK(self, cmd): - return self.streamer.federation_ack(cmd.token) + async def on_FEDERATION_ACK(self, cmd): + self.streamer.federation_ack(cmd.token) - def on_REMOVE_PUSHER(self, cmd): - return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) + async def on_REMOVE_PUSHER(self, cmd): + await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - def on_INVALIDATE_CACHE(self, cmd): - return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + async def on_INVALIDATE_CACHE(self, cmd): + self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - def on_USER_IP(self, cmd): - return self.streamer.on_user_ip( + async def on_USER_IP(self, cmd): + self.streamer.on_user_ip( cmd.user_id, cmd.access_token, cmd.ip, @@ -474,8 +470,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): cmd.last_seen, ) - @defer.inlineCallbacks - def subscribe_to_stream(self, stream_name, token): + async def subscribe_to_stream(self, stream_name, token): """Subscribe the remote to a stream. This invloves checking if they've missed anything and sending those @@ -487,7 +482,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): try: # Get missing updates - updates, current_token = yield self.streamer.get_stream_updates( + updates, current_token = await self.streamer.get_stream_updates( stream_name, token ) @@ -572,7 +567,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): """Called to handle a batch of replication data with a given stream token. Args: @@ -580,14 +575,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): token (int): stream token for this batch of rows rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. - - Returns: - Deferred|None """ raise NotImplementedError() @abc.abstractmethod - def on_position(self, stream_name, token): + async def on_position(self, stream_name, token): """Called when we get new position data.""" raise NotImplementedError() @@ -676,12 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): if not self.streams_connecting: self.handler.finished_connecting() - def on_SERVER(self, cmd): + async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - def on_RDATA(self, cmd): + async def on_RDATA(self, cmd): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -701,19 +693,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Check if this is the last of a batch of updates rows = self.pending_batches.pop(stream_name, []) rows.append(row) - return self.handler.on_rdata(stream_name, cmd.token, rows) + await self.handler.on_rdata(stream_name, cmd.token, rows) - def on_POSITION(self, cmd): + async def on_POSITION(self, cmd): # When we get a `POSITION` command it means we've finished getting # missing updates for the given stream, and are now up to date. self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - return self.handler.on_position(cmd.stream_name, cmd.token) + await self.handler.on_position(cmd.stream_name, cmd.token) - def on_SYNC(self, cmd): - return self.handler.on_sync(cmd.data) + async def on_SYNC(self, cmd): + self.handler.on_sync(cmd.data) def replicate(self, stream_name, token): """Send the subscription request to the server diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index cbfdaf5773..b1752e88cd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -23,7 +23,6 @@ from six import itervalues from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.protocol import Factory from synapse.metrics import LaterGauge @@ -155,8 +154,7 @@ class ReplicationStreamer(object): run_as_background_process("replication_notifier", self._run_notifier_loop) - @defer.inlineCallbacks - def _run_notifier_loop(self): + async def _run_notifier_loop(self): self.is_looping = True try: @@ -185,7 +183,7 @@ class ReplicationStreamer(object): continue if self._replication_torture_level: - yield self.clock.sleep( + await self.clock.sleep( self._replication_torture_level / 1000.0 ) @@ -196,7 +194,7 @@ class ReplicationStreamer(object): stream.upto_token, ) try: - updates, current_token = yield stream.get_updates() + updates, current_token = await stream.get_updates() except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise @@ -233,7 +231,7 @@ class ReplicationStreamer(object): self.is_looping = False @measure_func("repl.get_stream_updates") - def get_stream_updates(self, stream_name, token): + async def get_stream_updates(self, stream_name, token): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. """ @@ -241,7 +239,7 @@ class ReplicationStreamer(object): if not stream: raise Exception("unknown stream %s", stream_name) - return stream.get_updates_since(token) + return await stream.get_updates_since(token) @measure_func("repl.federation_ack") def federation_ack(self, token): @@ -252,22 +250,20 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") - @defer.inlineCallbacks - def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() - yield self.presence_handler.update_external_syncs_row( + await self.presence_handler.update_external_syncs_row( conn_id, user_id, is_syncing, last_sync_ms ) @measure_func("repl.on_remove_pusher") - @defer.inlineCallbacks - def on_remove_pusher(self, app_id, push_key, user_id): + async def on_remove_pusher(self, app_id, push_key, user_id): """A client has asked us to remove a pusher """ remove_pusher_counter.inc() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id=app_id, pushkey=push_key, user_id=user_id ) @@ -281,15 +277,16 @@ class ReplicationStreamer(object): getattr(self.store, cache_func).invalidate(tuple(keys)) @measure_func("repl.on_user_ip") - @defer.inlineCallbacks - def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): + async def on_user_ip( + self, user_id, access_token, ip, user_agent, device_id, last_seen + ): """The client saw a user request """ user_ip_cache_counter.inc() - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id, access_token, ip, user_agent, device_id, last_seen ) - yield self._server_notices_sender.on_user_ip(user_id) + await self._server_notices_sender.on_user_ip(user_id) def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4ab0334fc1..e03e77199b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -19,8 +19,6 @@ import logging from collections import namedtuple from typing import Any -from twisted.internet import defer - logger = logging.getLogger(__name__) @@ -144,8 +142,7 @@ class Stream(object): self.upto_token = self.current_token() self.last_token = self.upto_token - @defer.inlineCallbacks - def get_updates(self): + async def get_updates(self): """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before), until the `upto_token` @@ -156,13 +153,12 @@ class Stream(object): list of ``(token, row)`` entries. ``row`` will be json-serialised and sent over the replication steam. """ - updates, current_token = yield self.get_updates_since(self.last_token) + updates, current_token = await self.get_updates_since(self.last_token) self.last_token = current_token return updates, current_token - @defer.inlineCallbacks - def get_updates_since(self, from_token): + async def get_updates_since(self, from_token): """Like get_updates except allows specifying from when we should stream updates @@ -182,15 +178,16 @@ class Stream(object): if from_token == current_token: return [], current_token + logger.info("get_updates_since: %s", self.__class__) if self._LIMITED: - rows = yield self.update_function( + rows = await self.update_function( from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 ) # never turn more than MAX_EVENTS_BEHIND + 1 into updates. rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: - rows = yield self.update_function(from_token, current_token) + rows = await self.update_function(from_token, current_token) updates = [(row[0], row[1:]) for row in rows] @@ -295,9 +292,8 @@ class PushRulesStream(Stream): push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) + async def update_function(self, from_token, to_token, limit): + rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) return [(row[0], row[2]) for row in rows] @@ -413,9 +409,8 @@ class AccountDataStream(Stream): super(AccountDataStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - global_results, room_results = yield self.store.get_all_updated_account_data( + async def update_function(self, from_token, to_token, limit): + global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 0843e5aa90..b3afabb8cd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,8 +19,6 @@ from typing import Tuple, Type import attr -from twisted.internet import defer - from ._base import Stream @@ -122,16 +120,15 @@ class EventsStream(Stream): super(EventsStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, current_token, limit=None): - event_rows = yield self._store.get_all_new_forward_event_rows( + async def update_function(self, from_token, current_token, limit=None): + event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) event_updates = ( (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows ) - state_rows = yield self._store.get_all_updated_current_state_deltas( + state_rows = await self._store.get_all_updated_current_state_deltas( from_token, current_token, limit ) state_updates = ( -- cgit 1.4.1