From 32e7c9e7f20b57dd081023ac42d6931a8da9b3a3 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 20 Jun 2019 19:32:02 +1000 Subject: Run Black. (#5482) --- synapse/replication/tcp/resource.py | 54 +++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 20 deletions(-) (limited to 'synapse/replication/tcp/resource.py') diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index f6a38f5140..d1e98428bc 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -33,13 +33,15 @@ from .protocol import ServerReplicationStreamProtocol from .streams import STREAMS_MAP from .streams.federation import FederationStream -stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", - "", ["stream_name"]) +stream_updates_counter = Counter( + "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] +) user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache", - "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -48,6 +50,7 @@ logger = logging.getLogger(__name__) class ReplicationStreamProtocolFactory(Factory): """Factory for new replication connections. """ + def __init__(self, hs): self.streamer = ReplicationStreamer(hs) self.clock = hs.get_clock() @@ -55,9 +58,7 @@ class ReplicationStreamProtocolFactory(Factory): def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, - self.clock, - self.streamer, + self.server_name, self.clock, self.streamer ) @@ -80,29 +81,39 @@ class ReplicationStreamer(object): # Current connections. self.connections = [] - LaterGauge("synapse_replication_tcp_resource_total_connections", "", [], - lambda: len(self.connections)) + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self.connections), + ) # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. self.streams = [ - stream(hs) for stream in itervalues(STREAMS_MAP) + stream(hs) + for stream in itervalues(STREAMS_MAP) if stream != FederationStream or not hs.config.send_federation ] self.streams_by_name = {stream.NAME: stream for stream in self.streams} LaterGauge( - "synapse_replication_tcp_resource_connections_per_stream", "", + "synapse_replication_tcp_resource_connections_per_stream", + "", ["stream_name"], lambda: { - (stream_name,): len([ - conn for conn in self.connections - if stream_name in conn.replication_streams - ]) + (stream_name,): len( + [ + conn + for conn in self.connections + if stream_name in conn.replication_streams + ] + ) for stream_name in self.streams_by_name - }) + }, + ) self.federation_sender = None if not hs.config.send_federation: @@ -179,7 +190,9 @@ class ReplicationStreamer(object): logger.debug( "Getting stream: %s: %s -> %s", - stream.NAME, stream.last_token, stream.upto_token + stream.NAME, + stream.last_token, + stream.upto_token, ) try: updates, current_token = yield stream.get_updates() @@ -189,7 +202,8 @@ class ReplicationStreamer(object): logger.debug( "Sending %d updates to %d connections", - len(updates), len(self.connections), + len(updates), + len(self.connections), ) if updates: @@ -243,7 +257,7 @@ class ReplicationStreamer(object): """ user_sync_counter.inc() yield self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms, + conn_id, user_id, is_syncing, last_sync_ms ) @measure_func("repl.on_remove_pusher") @@ -272,7 +286,7 @@ class ReplicationStreamer(object): """ user_ip_cache_counter.inc() yield self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen, + user_id, access_token, ip, user_agent, device_id, last_seen ) yield self._server_notices_sender.on_user_ip(user_id) -- cgit 1.5.1 From e8b68a4e4b439065536c281d8997af85880f6ee2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 Jan 2020 14:08:06 +0000 Subject: Fixup synapse.replication to pass mypy checks (#6667) --- changelog.d/6667.misc | 1 + synapse/replication/http/_base.py | 10 ++--- synapse/replication/slave/storage/_base.py | 7 ++-- synapse/replication/slave/storage/presence.py | 2 +- synapse/replication/tcp/client.py | 12 +++--- synapse/replication/tcp/commands.py | 42 ++++++++++---------- synapse/replication/tcp/protocol.py | 36 ++++++++++------- synapse/replication/tcp/resource.py | 3 +- synapse/replication/tcp/streams/_base.py | 57 ++++++++++++++------------- synapse/replication/tcp/streams/events.py | 16 +++++--- synapse/replication/tcp/streams/federation.py | 4 +- tox.ini | 1 + 12 files changed, 105 insertions(+), 86 deletions(-) create mode 100644 changelog.d/6667.misc (limited to 'synapse/replication/tcp/resource.py') diff --git a/changelog.d/6667.misc b/changelog.d/6667.misc new file mode 100644 index 0000000000..227f80a508 --- /dev/null +++ b/changelog.d/6667.misc @@ -0,0 +1 @@ +Fixup `synapse.replication` to pass mypy checks. diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index c8056b0c0c..444eb7b7f4 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,6 +16,7 @@ import abc import logging import re +from typing import Dict, List, Tuple from six import raise_from from six.moves import urllib @@ -78,9 +79,8 @@ class ReplicationEndpoint(object): __metaclass__ = abc.ABCMeta - NAME = abc.abstractproperty() - PATH_ARGS = abc.abstractproperty() - + NAME = abc.abstractproperty() # type: str # type: ignore + PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore METHOD = "POST" CACHE = True RETRY_ON_TIMEOUT = True @@ -171,7 +171,7 @@ class ReplicationEndpoint(object): # have a good idea that the request has either succeeded or failed on # the master, and so whether we should clean up or not. while True: - headers = {} + headers = {} # type: Dict[bytes, List[bytes]] inject_active_span_byte_dict(headers, None, check_destination=False) try: result = yield request_func(uri, data, headers=headers) @@ -207,7 +207,7 @@ class ReplicationEndpoint(object): method = self.METHOD if self.CACHE: - handler = self._cached_handler + handler = self._cached_handler # type: ignore url_args.append("txn_id") args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index b91a528245..704282c800 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict +from typing import Dict, Optional import six @@ -41,7 +41,7 @@ class BaseSlavedStore(SQLBaseStore): if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( db_conn, "cache_invalidation_stream", "stream_id" - ) + ) # type: Optional[SlavedIdTracker] else: self._cache_id_gen = None @@ -62,7 +62,8 @@ class BaseSlavedStore(SQLBaseStore): def process_replication_rows(self, stream_name, token, rows): if stream_name == "caches": - self._cache_id_gen.advance(token) + if self._cache_id_gen: + self._cache_id_gen.advance(token) for row in rows: if row.cache_func == CURRENT_STATE_CACHE_NAME: room_id = row.keys[0] diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index f552e7c972..ad8f0c15a9 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -29,7 +29,7 @@ class SlavedPresenceStore(BaseSlavedStore): self._presence_on_startup = self._get_active_presence(db_conn) - self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache( + self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", self._presence_id_gen.get_current_token() ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index bbcb84646c..aa7fd90e26 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,7 +16,7 @@ """ import logging -from typing import Dict +from typing import Dict, List, Optional from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory @@ -28,6 +28,7 @@ from synapse.replication.tcp.protocol import ( ) from .commands import ( + Command, FederationAckCommand, InvalidateCacheCommand, RemovePusherCommand, @@ -89,15 +90,15 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): # Any pending commands to be sent once a new connection has been # established - self.pending_commands = [] + self.pending_commands = [] # type: List[Command] # Map from string -> deferred, to wake up when receiveing a SYNC with # the given string. # Used for tests. - self.awaiting_syncs = {} + self.awaiting_syncs = {} # type: Dict[str, defer.Deferred] # The factory used to create connections. - self.factory = None + self.factory = None # type: Optional[ReplicationClientFactory] def start_replication(self, hs): """Helper method to start a replication connection to the remote server @@ -235,4 +236,5 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): # We don't reset the delay any earlier as otherwise if there is a # problem during start up we'll end up tight looping connecting to the # server. - self.factory.resetDelay() + if self.factory: + self.factory.resetDelay() diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 0ff2a7199f..cbb36b9acf 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -20,15 +20,16 @@ allowed to be sent by which side. import logging import platform +from typing import Tuple, Type if platform.python_implementation() == "PyPy": import json _json_encoder = json.JSONEncoder() else: - import simplejson as json + import simplejson as json # type: ignore[no-redef] # noqa: F821 - _json_encoder = json.JSONEncoder(namedtuple_as_object=False) + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821 logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ class Command(object): The default implementation creates a command of form ` ` """ - NAME = None + NAME = None # type: str def __init__(self, data): self.data = data @@ -386,25 +387,24 @@ class UserIpCommand(Command): ) +_COMMANDS = ( + ServerCommand, + RdataCommand, + PositionCommand, + ErrorCommand, + PingCommand, + NameCommand, + ReplicateCommand, + UserSyncCommand, + FederationAckCommand, + SyncCommand, + RemovePusherCommand, + InvalidateCacheCommand, + UserIpCommand, +) # type: Tuple[Type[Command], ...] + # Map of command name to command type. -COMMAND_MAP = { - cmd.NAME: cmd - for cmd in ( - ServerCommand, - RdataCommand, - PositionCommand, - ErrorCommand, - PingCommand, - NameCommand, - ReplicateCommand, - UserSyncCommand, - FederationAckCommand, - SyncCommand, - RemovePusherCommand, - InvalidateCacheCommand, - UserIpCommand, - ) -} +COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS} # The commands the server is allowed to send VALID_SERVER_COMMANDS = ( diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index afaf002fe6..db0353c996 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -53,6 +53,7 @@ import fcntl import logging import struct from collections import defaultdict +from typing import Any, DefaultDict, Dict, List, Set, Tuple from six import iteritems, iterkeys @@ -65,13 +66,11 @@ from twisted.python.failure import Failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util import Clock -from synapse.util.stringutils import random_string - -from .commands import ( +from synapse.replication.tcp.commands import ( COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, + Command, ErrorCommand, NameCommand, PingCommand, @@ -82,6 +81,10 @@ from .commands import ( SyncCommand, UserSyncCommand, ) +from synapse.types import Collection +from synapse.util import Clock +from synapse.util.stringutils import random_string + from .streams import STREAMS_MAP connection_close_counter = Counter( @@ -124,8 +127,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): delimiter = b"\n" - VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive - VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send + # Valid commands we expect to receive + VALID_INBOUND_COMMANDS = [] # type: Collection[str] + + # Valid commands we can send + VALID_OUTBOUND_COMMANDS = [] # type: Collection[str] max_line_buffer = 10000 @@ -144,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.conn_id = random_string(5) # To dedupe in case of name clashes. # List of pending commands to send once we've established the connection - self.pending_commands = [] + self.pending_commands = [] # type: List[Command] # The LoopingCall for sending pings. self._send_ping_loop = None - self.inbound_commands_counter = defaultdict(int) - self.outbound_commands_counter = defaultdict(int) + self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int] + self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int] def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -409,14 +415,14 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.streamer = streamer # The streams the client has subscribed to and is up to date with - self.replication_streams = set() + self.replication_streams = set() # type: Set[str] # The streams the client is currently subscribing to. - self.connecting_streams = set() + self.connecting_streams = set() # type: Set[str] # Map from stream name to list of updates to send once we've finished # subscribing the client to the stream. - self.pending_rdata = {} + self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] def connectionMade(self): self.send_command(ServerCommand(self.server_name)) @@ -642,11 +648,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Set of stream names that have been subscribe to, but haven't yet # caught up with. This is used to track when the client has been fully # connected to the remote. - self.streams_connecting = set() + self.streams_connecting = set() # type: Set[str] # Map of stream to batched updates. See RdataCommand for info on how # batching works. - self.pending_batches = {} + self.pending_batches = {} # type: Dict[str, Any] def connectionMade(self): self.send_command(NameCommand(self.client_name)) @@ -766,7 +772,7 @@ def transport_kernel_read_buffer_size(protocol, read=True): op = SIOCINQ else: op = SIOCOUTQ - size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0] + size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0] return size return 0 diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index d1e98428bc..cbfdaf5773 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,6 +17,7 @@ import logging import random +from typing import List from six import itervalues @@ -79,7 +80,7 @@ class ReplicationStreamer(object): self._replication_torture_level = hs.config.replication_torture_level # Current connections. - self.connections = [] + self.connections = [] # type: List[ServerReplicationStreamProtocol] LaterGauge( "synapse_replication_tcp_resource_total_connections", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 8512923eae..4ab0334fc1 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,10 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging from collections import namedtuple +from typing import Any from twisted.internet import defer @@ -104,8 +104,9 @@ class Stream(object): time it was called up until the point `advance_current_token` was called. """ - NAME = None # The name of the stream - ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. + NAME = None # type: str # The name of the stream + # The type of the row. Used by the default impl of parse_row. + ROW_TYPE = None # type: Any _LIMITED = True # Whether the update function takes a limit @classmethod @@ -231,8 +232,8 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_current_backfill_token - self.update_function = store.get_all_new_backfill_event_rows + self.current_token = store.get_current_backfill_token # type: ignore + self.update_function = store.get_all_new_backfill_event_rows # type: ignore super(BackfillStream, self).__init__(hs) @@ -246,8 +247,8 @@ class PresenceStream(Stream): store = hs.get_datastore() presence_handler = hs.get_presence_handler() - self.current_token = store.get_current_presence_token - self.update_function = presence_handler.get_all_presence_updates + self.current_token = store.get_current_presence_token # type: ignore + self.update_function = presence_handler.get_all_presence_updates # type: ignore super(PresenceStream, self).__init__(hs) @@ -260,8 +261,8 @@ class TypingStream(Stream): def __init__(self, hs): typing_handler = hs.get_typing_handler() - self.current_token = typing_handler.get_current_token - self.update_function = typing_handler.get_all_typing_updates + self.current_token = typing_handler.get_current_token # type: ignore + self.update_function = typing_handler.get_all_typing_updates # type: ignore super(TypingStream, self).__init__(hs) @@ -273,8 +274,8 @@ class ReceiptsStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_max_receipt_stream_id - self.update_function = store.get_all_updated_receipts + self.current_token = store.get_max_receipt_stream_id # type: ignore + self.update_function = store.get_all_updated_receipts # type: ignore super(ReceiptsStream, self).__init__(hs) @@ -310,8 +311,8 @@ class PushersStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_pushers_stream_token - self.update_function = store.get_all_updated_pushers_rows + self.current_token = store.get_pushers_stream_token # type: ignore + self.update_function = store.get_all_updated_pushers_rows # type: ignore super(PushersStream, self).__init__(hs) @@ -327,8 +328,8 @@ class CachesStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_cache_stream_token - self.update_function = store.get_all_updated_caches + self.current_token = store.get_cache_stream_token # type: ignore + self.update_function = store.get_all_updated_caches # type: ignore super(CachesStream, self).__init__(hs) @@ -343,8 +344,8 @@ class PublicRoomsStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_current_public_room_stream_id - self.update_function = store.get_all_new_public_rooms + self.current_token = store.get_current_public_room_stream_id # type: ignore + self.update_function = store.get_all_new_public_rooms # type: ignore super(PublicRoomsStream, self).__init__(hs) @@ -360,8 +361,8 @@ class DeviceListsStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_device_stream_token - self.update_function = store.get_all_device_list_changes_for_remotes + self.current_token = store.get_device_stream_token # type: ignore + self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore super(DeviceListsStream, self).__init__(hs) @@ -376,8 +377,8 @@ class ToDeviceStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_to_device_stream_token - self.update_function = store.get_all_new_device_messages + self.current_token = store.get_to_device_stream_token # type: ignore + self.update_function = store.get_all_new_device_messages # type: ignore super(ToDeviceStream, self).__init__(hs) @@ -392,8 +393,8 @@ class TagAccountDataStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_max_account_data_stream_id - self.update_function = store.get_all_updated_tags + self.current_token = store.get_max_account_data_stream_id # type: ignore + self.update_function = store.get_all_updated_tags # type: ignore super(TagAccountDataStream, self).__init__(hs) @@ -408,7 +409,7 @@ class AccountDataStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() - self.current_token = self.store.get_max_account_data_stream_id + self.current_token = self.store.get_max_account_data_stream_id # type: ignore super(AccountDataStream, self).__init__(hs) @@ -434,8 +435,8 @@ class GroupServerStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_group_stream_token - self.update_function = store.get_all_groups_changes + self.current_token = store.get_group_stream_token # type: ignore + self.update_function = store.get_all_groups_changes # type: ignore super(GroupServerStream, self).__init__(hs) @@ -451,7 +452,7 @@ class UserSignatureStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_device_stream_token - self.update_function = store.get_all_user_signature_changes_for_remotes + self.current_token = store.get_device_stream_token # type: ignore + self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index d97669c886..0843e5aa90 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -13,7 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import heapq +from typing import Tuple, Type import attr @@ -63,7 +65,8 @@ class BaseEventsStreamRow(object): Specifies how to identify, serialize and deserialize the different types. """ - TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + # Unique string that ids the type. Must be overriden in sub classes. + TypeId = None # type: str @classmethod def from_data(cls, data): @@ -99,9 +102,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow): event_id = attr.ib() # str, optional -TypeToRow = { - Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow) -} +_EventRows = ( + EventsStreamEventRow, + EventsStreamCurrentStateRow, +) # type: Tuple[Type[BaseEventsStreamRow], ...] + +TypeToRow = {Row.TypeId: Row for Row in _EventRows} class EventsStream(Stream): @@ -112,7 +118,7 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() - self.current_token = self._store.get_current_events_token + self.current_token = self._store.get_current_events_token # type: ignore super(EventsStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index dc2484109d..615f3dc9ac 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -37,7 +37,7 @@ class FederationStream(Stream): def __init__(self, hs): federation_sender = hs.get_federation_sender() - self.current_token = federation_sender.get_current_token - self.update_function = federation_sender.get_replication_rows + self.current_token = federation_sender.get_current_token # type: ignore + self.update_function = federation_sender.get_replication_rows # type: ignore super(FederationStream, self).__init__(hs) diff --git a/tox.ini b/tox.ini index 0ab6d5666b..b73a993053 100644 --- a/tox.ini +++ b/tox.ini @@ -181,6 +181,7 @@ commands = mypy \ synapse/handlers/ui_auth \ synapse/logging/ \ synapse/module_api \ + synapse/replication \ synapse/rest/consent \ synapse/rest/saml2 \ synapse/spam_checker_api \ -- cgit 1.5.1 From 48c3a96886de64f3141ad68b8163cd2fc0c197ff Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 Jan 2020 09:16:12 +0000 Subject: Port synapse.replication.tcp to async/await (#6666) * Port synapse.replication.tcp to async/await * Newsfile * Correctly document type of on_ functions as async * Don't be overenthusiastic with the asyncing.... --- changelog.d/6666.misc | 1 + synapse/app/admin_cmd.py | 3 +- synapse/app/appservice.py | 5 +-- synapse/app/federation_sender.py | 5 +-- synapse/app/pusher.py | 5 +-- synapse/app/synchrotron.py | 5 +-- synapse/app/user_dir.py | 5 +-- synapse/federation/send_queue.py | 4 +- synapse/handlers/typing.py | 2 +- synapse/replication/tcp/client.py | 11 ++--- synapse/replication/tcp/protocol.py | 72 ++++++++++++++----------------- synapse/replication/tcp/resource.py | 31 ++++++------- synapse/replication/tcp/streams/_base.py | 25 +++++------ synapse/replication/tcp/streams/events.py | 9 ++-- tests/replication/tcp/streams/_base.py | 2 +- 15 files changed, 80 insertions(+), 105 deletions(-) create mode 100644 changelog.d/6666.misc (limited to 'synapse/replication/tcp/resource.py') diff --git a/changelog.d/6666.misc b/changelog.d/6666.misc new file mode 100644 index 0000000000..e79c23d2d2 --- /dev/null +++ b/changelog.d/6666.misc @@ -0,0 +1 @@ +Port `synapse.replication.tcp` to async/await. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 8e36bc57d3..1c7c6ec0c8 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -84,8 +84,7 @@ class AdminCmdServer(HomeServer): class AdminCmdReplicationHandler(ReplicationClientHandler): - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): pass def get_streams_to_replicate(self): diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index e82e0f11e3..2217d4a4fb 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -115,9 +115,8 @@ class ASReplicationHandler(ReplicationClientHandler): super(ASReplicationHandler, self).__init__(hs.get_datastore()) self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) if stream_name == "events": max_stream_id = self.store.get_room_max_stream_ordering() diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 83c436229c..a57cf991ac 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -145,9 +145,8 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) self.send_handler = FederationSenderHandler(hs, self) - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(FederationSenderReplicationHandler, self).on_rdata( + async def on_rdata(self, stream_name, token, rows): + await super(FederationSenderReplicationHandler, self).on_rdata( stream_name, token, rows ) self.send_handler.process_replication_rows(stream_name, token, rows) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 09e639040a..e46b6ac598 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -141,9 +141,8 @@ class PusherReplicationHandler(ReplicationClientHandler): self.pusher_pool = hs.get_pusherpool() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.poke_pushers, stream_name, token, rows) @defer.inlineCallbacks diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 03031ee34d..3218da07bd 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -358,9 +358,8 @@ class SyncReplicationHandler(ReplicationClientHandler): self.presence_handler = hs.get_presence_handler() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.process_and_notify, stream_name, token, rows) def get_streams_to_replicate(self): diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 1257098f92..ba536d6f04 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -172,9 +172,8 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) self.user_directory = hs.get_user_directory_handler() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(UserDirectoryReplicationHandler, self).on_rdata( + async def on_rdata(self, stream_name, token, rows): + await super(UserDirectoryReplicationHandler, self).on_rdata( stream_name, token, rows ) if stream_name == EventsStream.NAME: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index ced4925a98..174f6e42be 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -259,7 +259,9 @@ class FederationRemoteSendQueue(object): def federation_ack(self, token): self._clear_queue_before_pos(token) - def get_replication_rows(self, from_token, to_token, limit, federation_ack=None): + async def get_replication_rows( + self, from_token, to_token, limit, federation_ack=None + ): """Get rows to be sent over federation between the two tokens Args: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index b635c339ed..d5ca9cb07b 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -257,7 +257,7 @@ class TypingHandler(object): "typing_key", self._latest_room_serial, rooms=[member.room_id] ) - def get_all_typing_updates(self, last_id, current_id): + async def get_all_typing_updates(self, last_id, current_id): if last_id == current_id: return [] diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index aa7fd90e26..52a0aefe68 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -110,7 +110,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self.factory) - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -121,20 +121,17 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): token (int): stream token for this batch of rows rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. - - Returns: - Deferred|None """ logger.debug("Received rdata %s -> %s", stream_name, token) - return self.store.process_replication_rows(stream_name, token, rows) + self.store.process_replication_rows(stream_name, token, rows) - def on_position(self, stream_name, token): + async def on_position(self, stream_name, token): """Called when we get new position data. By default this just pokes the slave store. Can be overriden in subclasses to handle more. """ - return self.store.process_replication_rows(stream_name, token, []) + self.store.process_replication_rows(stream_name, token, []) def on_sync(self, data): """When we received a SYNC we wake up any deferreds that were waiting diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index db0353c996..5f4bdf84d2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -81,12 +81,11 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) +from synapse.replication.tcp.streams import STREAMS_MAP from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string -from .streams import STREAMS_MAP - connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -241,19 +240,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd ) - def handle_command(self, cmd): + async def handle_command(self, cmd: Command): """Handle a command we have received over the replication stream. - By default delegates to on_ + By default delegates to on_, which should return an awaitable. Args: - cmd (synapse.replication.tcp.commands.Command): received command - - Returns: - Deferred + cmd: received command """ handler = getattr(self, "on_%s" % (cmd.NAME,)) - return handler(cmd) + await handler(cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -326,10 +322,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - def on_PING(self, line): + async def on_PING(self, line): self.received_ping = True - def on_ERROR(self, cmd): + async def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -429,16 +425,16 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): BaseReplicationStreamProtocol.connectionMade(self) self.streamer.new_connection(self) - def on_NAME(self, cmd): + async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - def on_USER_SYNC(self, cmd): - return self.streamer.on_user_sync( + async def on_USER_SYNC(self, cmd): + await self.streamer.on_user_sync( self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) - def on_REPLICATE(self, cmd): + async def on_REPLICATE(self, cmd): stream_name = cmd.stream_name token = cmd.token @@ -449,23 +445,23 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): for stream in iterkeys(self.streamer.streams_by_name) ] - return make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) else: - return self.subscribe_to_stream(stream_name, token) + await self.subscribe_to_stream(stream_name, token) - def on_FEDERATION_ACK(self, cmd): - return self.streamer.federation_ack(cmd.token) + async def on_FEDERATION_ACK(self, cmd): + self.streamer.federation_ack(cmd.token) - def on_REMOVE_PUSHER(self, cmd): - return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) + async def on_REMOVE_PUSHER(self, cmd): + await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - def on_INVALIDATE_CACHE(self, cmd): - return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + async def on_INVALIDATE_CACHE(self, cmd): + self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - def on_USER_IP(self, cmd): - return self.streamer.on_user_ip( + async def on_USER_IP(self, cmd): + self.streamer.on_user_ip( cmd.user_id, cmd.access_token, cmd.ip, @@ -474,8 +470,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): cmd.last_seen, ) - @defer.inlineCallbacks - def subscribe_to_stream(self, stream_name, token): + async def subscribe_to_stream(self, stream_name, token): """Subscribe the remote to a stream. This invloves checking if they've missed anything and sending those @@ -487,7 +482,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): try: # Get missing updates - updates, current_token = yield self.streamer.get_stream_updates( + updates, current_token = await self.streamer.get_stream_updates( stream_name, token ) @@ -572,7 +567,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): """Called to handle a batch of replication data with a given stream token. Args: @@ -580,14 +575,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): token (int): stream token for this batch of rows rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. - - Returns: - Deferred|None """ raise NotImplementedError() @abc.abstractmethod - def on_position(self, stream_name, token): + async def on_position(self, stream_name, token): """Called when we get new position data.""" raise NotImplementedError() @@ -676,12 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): if not self.streams_connecting: self.handler.finished_connecting() - def on_SERVER(self, cmd): + async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - def on_RDATA(self, cmd): + async def on_RDATA(self, cmd): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -701,19 +693,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Check if this is the last of a batch of updates rows = self.pending_batches.pop(stream_name, []) rows.append(row) - return self.handler.on_rdata(stream_name, cmd.token, rows) + await self.handler.on_rdata(stream_name, cmd.token, rows) - def on_POSITION(self, cmd): + async def on_POSITION(self, cmd): # When we get a `POSITION` command it means we've finished getting # missing updates for the given stream, and are now up to date. self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - return self.handler.on_position(cmd.stream_name, cmd.token) + await self.handler.on_position(cmd.stream_name, cmd.token) - def on_SYNC(self, cmd): - return self.handler.on_sync(cmd.data) + async def on_SYNC(self, cmd): + self.handler.on_sync(cmd.data) def replicate(self, stream_name, token): """Send the subscription request to the server diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index cbfdaf5773..b1752e88cd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -23,7 +23,6 @@ from six import itervalues from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.protocol import Factory from synapse.metrics import LaterGauge @@ -155,8 +154,7 @@ class ReplicationStreamer(object): run_as_background_process("replication_notifier", self._run_notifier_loop) - @defer.inlineCallbacks - def _run_notifier_loop(self): + async def _run_notifier_loop(self): self.is_looping = True try: @@ -185,7 +183,7 @@ class ReplicationStreamer(object): continue if self._replication_torture_level: - yield self.clock.sleep( + await self.clock.sleep( self._replication_torture_level / 1000.0 ) @@ -196,7 +194,7 @@ class ReplicationStreamer(object): stream.upto_token, ) try: - updates, current_token = yield stream.get_updates() + updates, current_token = await stream.get_updates() except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise @@ -233,7 +231,7 @@ class ReplicationStreamer(object): self.is_looping = False @measure_func("repl.get_stream_updates") - def get_stream_updates(self, stream_name, token): + async def get_stream_updates(self, stream_name, token): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. """ @@ -241,7 +239,7 @@ class ReplicationStreamer(object): if not stream: raise Exception("unknown stream %s", stream_name) - return stream.get_updates_since(token) + return await stream.get_updates_since(token) @measure_func("repl.federation_ack") def federation_ack(self, token): @@ -252,22 +250,20 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") - @defer.inlineCallbacks - def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() - yield self.presence_handler.update_external_syncs_row( + await self.presence_handler.update_external_syncs_row( conn_id, user_id, is_syncing, last_sync_ms ) @measure_func("repl.on_remove_pusher") - @defer.inlineCallbacks - def on_remove_pusher(self, app_id, push_key, user_id): + async def on_remove_pusher(self, app_id, push_key, user_id): """A client has asked us to remove a pusher """ remove_pusher_counter.inc() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id=app_id, pushkey=push_key, user_id=user_id ) @@ -281,15 +277,16 @@ class ReplicationStreamer(object): getattr(self.store, cache_func).invalidate(tuple(keys)) @measure_func("repl.on_user_ip") - @defer.inlineCallbacks - def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): + async def on_user_ip( + self, user_id, access_token, ip, user_agent, device_id, last_seen + ): """The client saw a user request """ user_ip_cache_counter.inc() - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id, access_token, ip, user_agent, device_id, last_seen ) - yield self._server_notices_sender.on_user_ip(user_id) + await self._server_notices_sender.on_user_ip(user_id) def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4ab0334fc1..e03e77199b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -19,8 +19,6 @@ import logging from collections import namedtuple from typing import Any -from twisted.internet import defer - logger = logging.getLogger(__name__) @@ -144,8 +142,7 @@ class Stream(object): self.upto_token = self.current_token() self.last_token = self.upto_token - @defer.inlineCallbacks - def get_updates(self): + async def get_updates(self): """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before), until the `upto_token` @@ -156,13 +153,12 @@ class Stream(object): list of ``(token, row)`` entries. ``row`` will be json-serialised and sent over the replication steam. """ - updates, current_token = yield self.get_updates_since(self.last_token) + updates, current_token = await self.get_updates_since(self.last_token) self.last_token = current_token return updates, current_token - @defer.inlineCallbacks - def get_updates_since(self, from_token): + async def get_updates_since(self, from_token): """Like get_updates except allows specifying from when we should stream updates @@ -182,15 +178,16 @@ class Stream(object): if from_token == current_token: return [], current_token + logger.info("get_updates_since: %s", self.__class__) if self._LIMITED: - rows = yield self.update_function( + rows = await self.update_function( from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 ) # never turn more than MAX_EVENTS_BEHIND + 1 into updates. rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: - rows = yield self.update_function(from_token, current_token) + rows = await self.update_function(from_token, current_token) updates = [(row[0], row[1:]) for row in rows] @@ -295,9 +292,8 @@ class PushRulesStream(Stream): push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) + async def update_function(self, from_token, to_token, limit): + rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) return [(row[0], row[2]) for row in rows] @@ -413,9 +409,8 @@ class AccountDataStream(Stream): super(AccountDataStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - global_results, room_results = yield self.store.get_all_updated_account_data( + async def update_function(self, from_token, to_token, limit): + global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 0843e5aa90..b3afabb8cd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,8 +19,6 @@ from typing import Tuple, Type import attr -from twisted.internet import defer - from ._base import Stream @@ -122,16 +120,15 @@ class EventsStream(Stream): super(EventsStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, current_token, limit=None): - event_rows = yield self._store.get_all_new_forward_event_rows( + async def update_function(self, from_token, current_token, limit=None): + event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) event_updates = ( (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows ) - state_rows = yield self._store.get_all_updated_current_state_deltas( + state_rows = await self._store.get_all_updated_current_state_deltas( from_token, current_token, limit ) state_updates = ( diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 1d14e77255..e96ad4ca4e 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -73,6 +73,6 @@ class TestReplicationClientHandler(object): def finished_connecting(self): pass - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): for r in rows: self.received_rdata_rows.append((stream_name, token, r)) -- cgit 1.5.1 From a8a50f5b5746279379b4511c8ecb2a40b143fe32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 Jan 2020 10:27:19 +0000 Subject: Wake up transaction queue when remote server comes back online (#6706) This will be used to retry outbound transactions to a remote server if we think it might have come back up. --- changelog.d/6706.misc | 1 + docs/tcp_replication.md | 6 +++++- synapse/app/federation_sender.py | 12 +++++++++++- synapse/federation/sender/__init__.py | 18 ++++++++++++++++-- synapse/federation/transport/server.py | 19 ++++++++++++++++++- synapse/notifier.py | 31 ++++++++++++++++++++++++++++--- synapse/replication/tcp/client.py | 3 +++ synapse/replication/tcp/commands.py | 17 +++++++++++++++++ synapse/replication/tcp/protocol.py | 15 +++++++++++++++ synapse/replication/tcp/resource.py | 9 +++++++++ synapse/server.pyi | 12 ++++++++++++ 11 files changed, 135 insertions(+), 8 deletions(-) create mode 100644 changelog.d/6706.misc (limited to 'synapse/replication/tcp/resource.py') diff --git a/changelog.d/6706.misc b/changelog.d/6706.misc new file mode 100644 index 0000000000..1ac11cc04b --- /dev/null +++ b/changelog.d/6706.misc @@ -0,0 +1 @@ +Attempt to retry sending a transaction when we detect a remote server has come back online, rather than waiting for a transaction to be triggered by new data. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index ba9e874d07..a0b1d563ff 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -209,7 +209,7 @@ Where `` may be either: * a numeric stream_id to stream updates since (exclusive) * `NOW` to stream all subsequent updates. -The `` is the name of a replication stream to subscribe +The `` is the name of a replication stream to subscribe to (see [here](../synapse/replication/tcp/streams/_base.py) for a list of streams). It can also be `ALL` to subscribe to all known streams, in which case the `` must be set to `NOW`. @@ -234,6 +234,10 @@ in which case the `` must be set to `NOW`. Used exclusively in tests +### REMOTE_SERVER_UP (S, C) + + Inform other processes that a remote server may have come back online. + See `synapse/replication/tcp/commands.py` for a detailed description and the format of each command. diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index a57cf991ac..38d11fdd0f 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -158,6 +158,13 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): args.update(self.send_handler.stream_positions()) return args + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + + # Let's wake up the transaction queue for the server in case we have + # pending stuff to send to it. + self.send_handler.wake_destination(server) + def start(config_options): try: @@ -205,7 +212,7 @@ class FederationSenderHandler(object): to the federation sender. """ - def __init__(self, hs, replication_client): + def __init__(self, hs: FederationSenderServer, replication_client): self.store = hs.get_datastore() self._is_mine_id = hs.is_mine_id self.federation_sender = hs.get_federation_sender() @@ -226,6 +233,9 @@ class FederationSenderHandler(object): self.store.get_room_max_stream_ordering() ) + def wake_destination(self, server: str): + self.federation_sender.wake_destination(server) + def stream_positions(self): return {"federation": self.federation_position} diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4ebb0e8bc0..36c83c3027 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -21,6 +21,7 @@ from prometheus_client import Counter from twisted.internet import defer +import synapse import synapse.metrics from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager @@ -54,7 +55,7 @@ sent_pdus_destination_dist_total = Counter( class FederationSender(object): - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs self.server_name = hs.hostname @@ -482,7 +483,20 @@ class FederationSender(object): def send_device_messages(self, destination): if destination == self.server_name: - logger.info("Not sending device update to ourselves") + logger.warning("Not sending device update to ourselves") + return + + self._get_per_destination_queue(destination).attempt_new_transaction() + + def wake_destination(self, destination: str): + """Called when we want to retry sending transactions to a remote. + + This is mainly useful if the remote server has been down and we think it + might have come back. + """ + + if destination == self.server_name: + logger.warning("Not waking up ourselves") return self._get_per_destination_queue(destination).attempt_new_transaction() diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index b4cbf23394..d8cf9ed299 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -44,6 +44,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string @@ -101,12 +102,17 @@ class NoAuthenticationError(AuthenticationError): class Authenticator(object): - def __init__(self, hs): + def __init__(self, hs: HomeServer): self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname self.store = hs.get_datastore() self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.notifer = hs.get_notifier() + + self.replication_client = None + if hs.config.worker.worker_app: + self.replication_client = hs.get_tcp_replication() # A method just so we can pass 'self' as the authenticator to the Servlets async def authenticate_request(self, request, content): @@ -166,6 +172,17 @@ class Authenticator(object): try: logger.info("Marking origin %r as up", origin) await self.store.set_destination_retry_timings(origin, None, 0, 0) + + # Inform the relevant places that the remote server is back up. + self.notifer.notify_remote_server_up(origin) + if self.replication_client: + # If we're on a worker we try and inform master about this. The + # replication client doesn't hook into the notifier to avoid + # infinite loops where we send a `REMOTE_SERVER_UP` command to + # master, which then echoes it back to us which in turn pokes + # the notifier. + self.replication_client.send_remote_server_up(origin) + except Exception: logger.exception("Error resetting retry timings on %s", origin) diff --git a/synapse/notifier.py b/synapse/notifier.py index 5f5f765bea..6132727cbd 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,11 +15,13 @@ import logging from collections import namedtuple +from typing import Callable, List from prometheus_client import Counter from twisted.internet import defer +import synapse.server from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state @@ -154,7 +156,7 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.user_to_user_stream = {} self.room_to_user_streams = {} @@ -164,7 +166,12 @@ class Notifier(object): self.store = hs.get_datastore() self.pending_new_room_events = [] - self.replication_callbacks = [] + # Called when there are new things to stream over replication + self.replication_callbacks = [] # type: List[Callable[[], None]] + + # Called when remote servers have come back online after having been + # down. + self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]] self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() @@ -205,7 +212,7 @@ class Notifier(object): "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream) ) - def add_replication_callback(self, cb): + def add_replication_callback(self, cb: Callable[[], None]): """Add a callback that will be called when some new data is available. Callback is not given any arguments. It should *not* return a Deferred - if it needs to do any asynchronous work, a background thread should be started and @@ -213,6 +220,12 @@ class Notifier(object): """ self.replication_callbacks.append(cb) + def add_remote_server_up_callback(self, cb: Callable[[str], None]): + """Add a callback that will be called when synapse detects a server + has been + """ + self.remote_server_up_callbacks.append(cb) + def on_new_room_event( self, event, room_stream_id, max_room_stream_id, extra_users=[] ): @@ -522,3 +535,15 @@ class Notifier(object): """Notify the any replication listeners that there's a new event""" for cb in self.replication_callbacks: cb() + + def notify_remote_server_up(self, server: str): + """Notify any replication that a remote server has come back up + """ + # We call federation_sender directly rather than registering as a + # callback as a) we already have a reference to it and b) it introduces + # circular dependencies. + if self.federation_sender: + self.federation_sender.wake_destination(server) + + for cb in self.remote_server_up_callbacks: + cb(server) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 52a0aefe68..fc06a7b053 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -143,6 +143,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): if d: d.callback(data) + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index cbb36b9acf..451671412d 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -387,6 +387,20 @@ class UserIpCommand(Command): ) +class RemoteServerUpCommand(Command): + """Sent when a worker has detected that a remote server is no longer + "down" and retry timings should be reset. + + If sent from a client the server will relay to all other workers. + + Format:: + + REMOTE_SERVER_UP + """ + + NAME = "REMOTE_SERVER_UP" + + _COMMANDS = ( ServerCommand, RdataCommand, @@ -401,6 +415,7 @@ _COMMANDS = ( RemovePusherCommand, InvalidateCacheCommand, UserIpCommand, + RemoteServerUpCommand, ) # type: Tuple[Type[Command], ...] # Map of command name to command type. @@ -414,6 +429,7 @@ VALID_SERVER_COMMANDS = ( ErrorCommand.NAME, PingCommand.NAME, SyncCommand.NAME, + RemoteServerUpCommand.NAME, ) # The commands the client is allowed to send @@ -427,4 +443,5 @@ VALID_CLIENT_COMMANDS = ( InvalidateCacheCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, + RemoteServerUpCommand.NAME, ) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 5f4bdf84d2..131e5acb09 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -76,6 +76,7 @@ from synapse.replication.tcp.commands import ( PingCommand, PositionCommand, RdataCommand, + RemoteServerUpCommand, ReplicateCommand, ServerCommand, SyncCommand, @@ -460,6 +461,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_INVALIDATE_CACHE(self, cmd): self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + self.streamer.on_remote_server_up(cmd.data) + async def on_USER_IP(self, cmd): self.streamer.on_user_ip( cmd.user_id, @@ -555,6 +559,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): def send_sync(self, data): self.send_command(SyncCommand(data)) + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) + def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) self.streamer.lost_connection(self) @@ -588,6 +595,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): """Called when get a new SYNC command.""" raise NotImplementedError() + @abc.abstractmethod + async def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + raise NotImplementedError() + @abc.abstractmethod def get_streams_to_replicate(self): """Called when a new connection has been established and we need to @@ -707,6 +719,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_SYNC(self, cmd): self.handler.on_sync(cmd.data) + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + self.handler.on_remote_server_up(cmd.data) + def replicate(self, stream_name, token): """Send the subscription request to the server """ diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index b1752e88cd..6ebf944f66 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -120,6 +120,7 @@ class ReplicationStreamer(object): self.federation_sender = hs.get_federation_sender() self.notifier.add_replication_callback(self.on_notifier_poke) + self.notifier.add_remote_server_up_callback(self.send_remote_server_up) # Keeps track of whether we are currently checking for updates self.is_looping = False @@ -288,6 +289,14 @@ class ReplicationStreamer(object): ) await self._server_notices_sender.on_user_ip(user_id) + @measure_func("repl.on_remote_server_up") + def on_remote_server_up(self, server: str): + self.notifier.notify_remote_server_up(server) + + def send_remote_server_up(self, server: str): + for conn in self.connections: + conn.send_remote_server_up(server) + def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. diff --git a/synapse/server.pyi b/synapse/server.pyi index b5e0b57095..0731403047 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,5 @@ +import twisted.internet + import synapse.api.auth import synapse.config.homeserver import synapse.federation.sender @@ -9,10 +11,12 @@ import synapse.handlers.deactivate_account import synapse.handlers.device import synapse.handlers.e2e_keys import synapse.handlers.message +import synapse.handlers.presence import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password import synapse.http.client +import synapse.notifier import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -85,3 +89,11 @@ class HomeServer(object): self, ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass + def get_notifier(self) -> synapse.notifier.Notifier: + pass + def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler: + pass + def get_clock(self) -> synapse.util.Clock: + pass + def get_reactor(self) -> twisted.internet.base.ReactorBase: + pass -- cgit 1.5.1 From d5275fc55f4edc42d1543825da2c13df63d96927 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 27 Jan 2020 13:47:50 +0000 Subject: Propagate cache invalidates from workers to other workers. (#6748) Currently if a worker invalidates a cache it will be streamed to master, which then didn't forward those to other workers. --- changelog.d/6748.misc | 1 + synapse/replication/tcp/protocol.py | 2 +- synapse/replication/tcp/resource.py | 9 ++++++--- synapse/storage/data_stores/main/cache.py | 22 +++++++++++++++++++++- 4 files changed, 29 insertions(+), 5 deletions(-) create mode 100644 changelog.d/6748.misc (limited to 'synapse/replication/tcp/resource.py') diff --git a/changelog.d/6748.misc b/changelog.d/6748.misc new file mode 100644 index 0000000000..de320d4cd9 --- /dev/null +++ b/changelog.d/6748.misc @@ -0,0 +1 @@ +Propagate cache invalidates from workers to other workers. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 131e5acb09..bc1482a9bb 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -459,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) async def on_INVALIDATE_CACHE(self, cmd): - self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): self.streamer.on_remote_server_up(cmd.data) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 6ebf944f66..ce60ae2e07 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,7 +17,7 @@ import logging import random -from typing import List +from typing import Any, List from six import itervalues @@ -271,11 +271,14 @@ class ReplicationStreamer(object): self.notifier.on_new_replication_data() @measure_func("repl.on_invalidate_cache") - def on_invalidate_cache(self, cache_func, keys): + async def on_invalidate_cache(self, cache_func: str, keys: List[Any]): """The client has asked us to invalidate a cache """ invalidate_cache_counter.inc() - getattr(self.store, cache_func).invalidate(tuple(keys)) + + # We invalidate the cache locally, but then also stream that to other + # workers. + await self.store.invalidate_cache_and_stream(cache_func, tuple(keys)) @measure_func("repl.on_user_ip") async def on_user_ip( diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index afa2b41c98..d4c44dcc75 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -16,7 +16,7 @@ import itertools import logging -from typing import Any, Iterable, Optional +from typing import Any, Iterable, Optional, Tuple from twisted.internet import defer @@ -33,6 +33,26 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class CacheInvalidationStore(SQLBaseStore): + async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + cache_func = getattr(self, cache_name, None) + if not cache_func: + return + + cache_func.invalidate(keys) + await self.runInteraction( + "invalidate_cache_and_stream", + self._send_invalidation_to_replication, + cache_func.__name__, + keys, + ) + def _invalidate_cache_and_stream(self, txn, cache_func, keys): """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. -- cgit 1.5.1 From 1f773eec912e4908ab60f7823f5c0a024261af4d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 Feb 2020 15:33:26 +0000 Subject: Port PresenceHandler to async/await (#6991) --- changelog.d/6991.misc | 1 + synapse/handlers/message.py | 5 +- synapse/handlers/presence.py | 192 ++++++++++++++++-------------------- synapse/replication/tcp/resource.py | 6 +- synapse/server.pyi | 5 + tests/handlers/test_presence.py | 18 ++-- tox.ini | 1 + 7 files changed, 113 insertions(+), 115 deletions(-) create mode 100644 changelog.d/6991.misc (limited to 'synapse/replication/tcp/resource.py') diff --git a/changelog.d/6991.misc b/changelog.d/6991.misc new file mode 100644 index 0000000000..5130f4e8af --- /dev/null +++ b/changelog.d/6991.misc @@ -0,0 +1 @@ +Port `synapse.handlers.presence` to async/await. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6be280952..a0103addd3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1016,11 +1016,10 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - @defer.inlineCallbacks - def _bump_active_time(self, user): + async def _bump_active_time(self, user): try: presence = self.hs.get_presence_handler() - yield presence.bump_presence_active_time(user) + await presence.bump_presence_active_time(user) except Exception: logger.exception("Error bumping presence active time") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 0d6cf2b008..5526015ddb 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -24,11 +24,12 @@ The methods that define policy are: import logging from contextlib import contextmanager -from typing import Dict, Set +from typing import Dict, List, Set from six import iteritems, itervalues from prometheus_client import Counter +from typing_extensions import ContextManager from twisted.internet import defer @@ -42,10 +43,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer +MYPY = False +if MYPY: + import synapse.server + logger = logging.getLogger(__name__) @@ -97,7 +102,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class PresenceHandler(object): def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs - self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.server_name = hs.hostname self.clock = hs.get_clock() @@ -150,7 +154,7 @@ class PresenceHandler(object): # Set of users who have presence in the `user_to_current_state` that # have not yet been persisted - self.unpersisted_users_changes = set() + self.unpersisted_users_changes = set() # type: Set[str] hs.get_reactor().addSystemEventTrigger( "before", @@ -160,12 +164,11 @@ class PresenceHandler(object): self._on_shutdown, ) - self.serial_to_user = {} self._next_serial = 1 # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs = {} + self.user_to_num_current_syncs = {} # type: Dict[str, int] # Keeps track of the number of *ongoing* syncs on other processes. # While any sync is ongoing on another process the user will never @@ -213,8 +216,7 @@ class PresenceHandler(object): self._event_pos = self.store.get_current_events_token() self._event_processing = False - @defer.inlineCallbacks - def _on_shutdown(self): + async def _on_shutdown(self): """Gets called when shutting down. This lets us persist any updates that we haven't yet persisted, e.g. updates that only changes some internal timers. This allows changes to persist across startup without having to @@ -235,7 +237,7 @@ class PresenceHandler(object): if self.unpersisted_users_changes: - yield self.store.update_presence( + await self.store.update_presence( [ self.user_to_current_state[user_id] for user_id in self.unpersisted_users_changes @@ -243,8 +245,7 @@ class PresenceHandler(object): ) logger.info("Finished _on_shutdown") - @defer.inlineCallbacks - def _persist_unpersisted_changes(self): + async def _persist_unpersisted_changes(self): """We periodically persist the unpersisted changes, as otherwise they may stack up and slow down shutdown times. """ @@ -253,12 +254,11 @@ class PresenceHandler(object): if unpersisted: logger.info("Persisting %d unpersisted presence updates", len(unpersisted)) - yield self.store.update_presence( + await self.store.update_presence( [self.user_to_current_state[user_id] for user_id in unpersisted] ) - @defer.inlineCallbacks - def _update_states(self, new_states): + async def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state should be sent to clients/servers. @@ -267,7 +267,7 @@ class PresenceHandler(object): with Measure(self.clock, "presence_update_states"): - # NOTE: We purposefully don't yield between now and when we've + # NOTE: We purposefully don't await between now and when we've # calculated what we want to do with the new states, to avoid races. to_notify = {} # Changes we want to notify everyone about @@ -311,7 +311,7 @@ class PresenceHandler(object): if to_notify: notified_presence_counter.inc(len(to_notify)) - yield self._persist_and_notify(list(to_notify.values())) + await self._persist_and_notify(list(to_notify.values())) self.unpersisted_users_changes |= {s.user_id for s in new_states} self.unpersisted_users_changes -= set(to_notify.keys()) @@ -326,7 +326,7 @@ class PresenceHandler(object): self._push_to_remotes(to_federation_ping.values()) - def _handle_timeouts(self): + async def _handle_timeouts(self): """Checks the presence of users that have timed out and updates as appropriate. """ @@ -368,10 +368,9 @@ class PresenceHandler(object): now=now, ) - return self._update_states(changes) + return await self._update_states(changes) - @defer.inlineCallbacks - def bump_presence_active_time(self, user): + async def bump_presence_active_time(self, user): """We've seen the user do something that indicates they're interacting with the app. """ @@ -383,16 +382,17 @@ class PresenceHandler(object): bump_active_time_counter.inc() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def user_syncing(self, user_id, affect_presence=True): + async def user_syncing( + self, user_id: str, affect_presence: bool = True + ) -> ContextManager[None]: """Returns a context manager that should surround any stream requests from the user. @@ -415,11 +415,11 @@ class PresenceHandler(object): curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) if prev_state.state == PresenceState.OFFLINE: # If they're currently offline then bring them online, otherwise # just update the last sync times. - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( state=PresenceState.ONLINE, @@ -429,7 +429,7 @@ class PresenceHandler(object): ] ) else: - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -437,13 +437,12 @@ class PresenceHandler(object): ] ) - @defer.inlineCallbacks - def _end(): + async def _end(): try: self.user_to_num_current_syncs[user_id] -= 1 - prev_state = yield self.current_state_for_user(user_id) - yield self._update_states( + prev_state = await self.current_state_for_user(user_id) + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -480,8 +479,7 @@ class PresenceHandler(object): else: return set() - @defer.inlineCallbacks - def update_external_syncs_row( + async def update_external_syncs_row( self, process_id, user_id, is_syncing, sync_time_msec ): """Update the syncing users for an external process as a delta. @@ -494,8 +492,8 @@ class PresenceHandler(object): is_syncing (bool): Whether or not the user is now syncing sync_time_msec(int): Time in ms when the user was last syncing """ - with (yield self.external_sync_linearizer.queue(process_id)): - prev_state = yield self.current_state_for_user(user_id) + with (await self.external_sync_linearizer.queue(process_id)): + prev_state = await self.current_state_for_user(user_id) process_presence = self.external_process_to_current_syncs.setdefault( process_id, set() @@ -525,25 +523,24 @@ class PresenceHandler(object): process_presence.discard(user_id) if updates: - yield self._update_states(updates) + await self._update_states(updates) self.external_process_last_updated_ms[process_id] = self.clock.time_msec() - @defer.inlineCallbacks - def update_external_syncs_clear(self, process_id): + async def update_external_syncs_clear(self, process_id): """Marks all users that had been marked as syncing by a given process as offline. Used when the process has stopped/disappeared. """ - with (yield self.external_sync_linearizer.queue(process_id)): + with (await self.external_sync_linearizer.queue(process_id)): process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = yield self.current_state_for_users(process_presence) + prev_states = await self.current_state_for_users(process_presence) time_now_ms = self.clock.time_msec() - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) for prev_state in itervalues(prev_states) @@ -551,15 +548,13 @@ class PresenceHandler(object): ) self.external_process_last_updated_ms.pop(process_id, None) - @defer.inlineCallbacks - def current_state_for_user(self, user_id): + async def current_state_for_user(self, user_id): """Get the current presence state for a user. """ - res = yield self.current_state_for_users([user_id]) + res = await self.current_state_for_users([user_id]) return res[user_id] - @defer.inlineCallbacks - def current_state_for_users(self, user_ids): + async def current_state_for_users(self, user_ids): """Get the current presence state for multiple users. Returns: @@ -574,7 +569,7 @@ class PresenceHandler(object): if missing: # There are things not in our in memory cache. Lets pull them out of # the database. - res = yield self.store.get_presence_for_users(missing) + res = await self.store.get_presence_for_users(missing) states.update(res) missing = [user_id for user_id, state in iteritems(states) if not state] @@ -587,14 +582,13 @@ class PresenceHandler(object): return states - @defer.inlineCallbacks - def _persist_and_notify(self, states): + async def _persist_and_notify(self, states): """Persist states in the database, poke the notifier and send to interested remote servers """ - stream_id, max_token = yield self.store.update_presence(states) + stream_id, max_token = await self.store.update_presence(states) - parties = yield get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -606,9 +600,8 @@ class PresenceHandler(object): self._push_to_remotes(states) - @defer.inlineCallbacks - def notify_for_states(self, state, stream_id): - parties = yield get_interested_parties(self.store, [state]) + async def notify_for_states(self, state, stream_id): + parties = await get_interested_parties(self.store, [state]) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -626,8 +619,7 @@ class PresenceHandler(object): """ self.federation.send_presence(states) - @defer.inlineCallbacks - def incoming_presence(self, origin, content): + async def incoming_presence(self, origin, content): """Called when we receive a `m.presence` EDU from a remote server. """ now = self.clock.time_msec() @@ -670,21 +662,19 @@ class PresenceHandler(object): new_fields["status_msg"] = push.get("status_msg", None) new_fields["currently_active"] = push.get("currently_active", False) - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) updates.append(prev_state.copy_and_replace(**new_fields)) if updates: federation_presence_counter.inc(len(updates)) - yield self._update_states(updates) + await self._update_states(updates) - @defer.inlineCallbacks - def get_state(self, target_user, as_event=False): - results = yield self.get_states([target_user.to_string()], as_event=as_event) + async def get_state(self, target_user, as_event=False): + results = await self.get_states([target_user.to_string()], as_event=as_event) return results[0] - @defer.inlineCallbacks - def get_states(self, target_user_ids, as_event=False): + async def get_states(self, target_user_ids, as_event=False): """Get the presence state for users. Args: @@ -695,7 +685,7 @@ class PresenceHandler(object): list """ - updates = yield self.current_state_for_users(target_user_ids) + updates = await self.current_state_for_users(target_user_ids) updates = list(updates.values()) for user_id in set(target_user_ids) - {u.user_id for u in updates}: @@ -713,8 +703,7 @@ class PresenceHandler(object): else: return updates - @defer.inlineCallbacks - def set_state(self, target_user, state, ignore_status_msg=False): + async def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ status_msg = state.get("status_msg", None) @@ -730,7 +719,7 @@ class PresenceHandler(object): user_id = target_user.to_string() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"state": presence} @@ -741,16 +730,15 @@ class PresenceHandler(object): if presence == PresenceState.ONLINE: new_fields["last_active_ts"] = self.clock.time_msec() - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def is_visible(self, observed_user, observer_user): + async def is_visible(self, observed_user, observer_user): """Returns whether a user can see another user's presence. """ - observer_room_ids = yield self.store.get_rooms_for_user( + observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() ) - observed_room_ids = yield self.store.get_rooms_for_user( + observed_room_ids = await self.store.get_rooms_for_user( observed_user.to_string() ) @@ -759,8 +747,7 @@ class PresenceHandler(object): return False - @defer.inlineCallbacks - def get_all_presence_updates(self, last_id, current_id): + async def get_all_presence_updates(self, last_id, current_id): """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -775,7 +762,7 @@ class PresenceHandler(object): """ # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = yield self.store.get_all_presence_updates(last_id, current_id) + rows = await self.store.get_all_presence_updates(last_id, current_id) return rows def notify_new_event(self): @@ -786,20 +773,18 @@ class PresenceHandler(object): if self._event_processing: return - @defer.inlineCallbacks - def _process_presence(): + async def _process_presence(): assert not self._event_processing self._event_processing = True try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._event_processing = False run_as_background_process("presence.notify_new_event", _process_presence) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "presence_delta"): @@ -812,10 +797,10 @@ class PresenceHandler(object): self._event_pos, room_max_stream_ordering, ) - max_pos, deltas = yield self.store.get_current_state_deltas( + max_pos, deltas = await self.store.get_current_state_deltas( self._event_pos, room_max_stream_ordering ) - yield self._handle_state_delta(deltas) + await self._handle_state_delta(deltas) self._event_pos = max_pos @@ -824,8 +809,7 @@ class PresenceHandler(object): max_pos ) - @defer.inlineCallbacks - def _handle_state_delta(self, deltas): + async def _handle_state_delta(self, deltas): """Process current state deltas to find new joins that need to be handled. """ @@ -846,13 +830,13 @@ class PresenceHandler(object): # joins. continue - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event or event.content.get("membership") != Membership.JOIN: # We only care about joins continue if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if ( prev_event and prev_event.content.get("membership") == Membership.JOIN @@ -860,10 +844,9 @@ class PresenceHandler(object): # Ignore changes to join events. continue - yield self._on_user_joined_room(room_id, state_key) + await self._on_user_joined_room(room_id, state_key) - @defer.inlineCallbacks - def _on_user_joined_room(self, room_id, user_id): + async def _on_user_joined_room(self, room_id, user_id): """Called when we detect a user joining the room via the current state delta stream. @@ -882,8 +865,8 @@ class PresenceHandler(object): # TODO: We should be able to filter the hosts down to those that # haven't previously seen the user - state = yield self.current_state_for_user(user_id) - hosts = yield self.state.get_current_hosts_in_room(room_id) + state = await self.current_state_for_user(user_id) + hosts = await self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. hosts = {host for host in hosts if host != self.server_name} @@ -903,10 +886,10 @@ class PresenceHandler(object): # TODO: Check that this is actually a new server joining the # room. - user_ids = yield self.state.get_current_users_in_room(room_id) + user_ids = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, user_ids)) - states = yield self.current_state_for_users(user_ids) + states = await self.current_state_for_users(user_ids) # Filter out old presence, i.e. offline presence states where # the user hasn't been active for a week. We can change this @@ -996,9 +979,8 @@ class PresenceEventSource(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - @defer.inlineCallbacks @log_function - def get_new_events( + async def get_new_events( self, user, from_key, @@ -1045,7 +1027,7 @@ class PresenceEventSource(object): presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache - users_interested_in = yield self._get_interested_in(user, explicit_room_id) + users_interested_in = await self._get_interested_in(user, explicit_room_id) user_ids_changed = set() changed = None @@ -1071,7 +1053,7 @@ class PresenceEventSource(object): else: user_ids_changed = users_interested_in - updates = yield presence.current_state_for_users(user_ids_changed) + updates = await presence.current_state_for_users(user_ids_changed) if include_offline: return (list(updates.values()), max_token) @@ -1084,11 +1066,11 @@ class PresenceEventSource(object): def get_current_key(self): return self.store.get_current_presence_token() - def get_pagination_rows(self, user, pagination_config, key): - return self.get_new_events(user, from_key=None, include_offline=False) + async def get_pagination_rows(self, user, pagination_config, key): + return await self.get_new_events(user, from_key=None, include_offline=False) - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _get_interested_in(self, user, explicit_room_id, cache_context): + @cached(num_args=2, cache_context=True) + async def _get_interested_in(self, user, explicit_room_id, cache_context): """Returns the set of users that the given user should see presence updates for """ @@ -1096,13 +1078,13 @@ class PresenceEventSource(object): users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(users_who_share_room) if explicit_room_id: - user_ids = yield self.store.get_users_in_room( + user_ids = await self.store.get_users_in_room( explicit_room_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(user_ids) @@ -1277,8 +1259,8 @@ def get_interested_parties(store, states): 2-tuple: `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ - room_ids_to_states = {} - users_to_states = {} + room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] + users_to_states = {} # type: Dict[str, List[UserPresenceState]] for state in states: room_ids = yield store.get_rooms_for_user(state.user_id) for room_id in room_ids: diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ce60ae2e07..ce9d1fae12 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -323,7 +323,11 @@ class ReplicationStreamer(object): # We need to tell the presence handler that the connection has been # lost so that it can handle any ongoing syncs on that connection. - self.presence_handler.update_external_syncs_clear(connection.conn_id) + run_as_background_process( + "update_external_syncs_clear", + self.presence_handler.update_external_syncs_clear, + connection.conn_id, + ) def _batch_updates(updates): diff --git a/synapse/server.pyi b/synapse/server.pyi index 40eabfe5d9..3844f0e12f 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -3,6 +3,7 @@ import twisted.internet import synapse.api.auth import synapse.config.homeserver import synapse.crypto.keyring +import synapse.federation.federation_server import synapse.federation.sender import synapse.federation.transport.client import synapse.handlers @@ -107,5 +108,9 @@ class HomeServer(object): self, ) -> synapse.replication.tcp.client.ReplicationClientHandler: pass + def get_federation_registry( + self, + ) -> synapse.federation.federation_server.FederationHandlerRegistry: + pass def is_mine_id(self, domain_id: str) -> bool: pass diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 64915bafcd..05ea40a7de 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -494,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, "@test2:server") # Mark test2 as online, test will be offline with a last_active of 0 - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -543,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) # Mark test as online - self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + ) ) # Mark test2 as online, test will be offline with a last_active of 0. # Note we don't join them to the room yet - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) # Add servers to the room diff --git a/tox.ini b/tox.ini index b715ea0bff..4ccfde01b5 100644 --- a/tox.ini +++ b/tox.ini @@ -183,6 +183,7 @@ commands = mypy \ synapse/events/spamcheck.py \ synapse/federation/sender \ synapse/federation/transport \ + synapse/handlers/presence.py \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ synapse/logging/ \ -- cgit 1.5.1