From 020add50997f697c7847ac84b86b457ba2f3e32d Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 1 Nov 2019 02:43:24 +1100 Subject: Update black to 19.10b0 (#6304) * update version of black and also fix the mypy config being overridden --- synapse/server.pyi | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'synapse/server.pyi') diff --git a/synapse/server.pyi b/synapse/server.pyi index 16f8f6b573..83d1f11283 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -39,7 +39,7 @@ class HomeServer(object): def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass def get_deactivate_account_handler( - self + self, ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: pass def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: @@ -47,32 +47,32 @@ class HomeServer(object): def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: pass def get_event_creation_handler( - self + self, ) -> synapse.handlers.message.EventCreationHandler: pass def get_set_password_handler( - self + self, ) -> synapse.handlers.set_password.SetPasswordHandler: pass def get_federation_sender(self) -> synapse.federation.sender.FederationSender: pass def get_federation_transport_client( - self + self, ) -> synapse.federation.transport.client.TransportLayerClient: pass def get_media_repository_resource( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: pass def get_media_repository( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepository: pass def get_server_notices_manager( - self + self, ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: pass def get_server_notices_sender( - self + self, ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass -- cgit 1.4.1 From 1cb84c6486a5131dd284f341bb657434becda255 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 1 Nov 2019 14:07:44 +0000 Subject: Support for routing outbound HTTP requests via a proxy (#6239) The `http_proxy` and `HTTPS_PROXY` env vars can be set to a `host[:port]` value which should point to a proxy. The address of the proxy should be excluded from IP blacklists such as the `url_preview_ip_range_blacklist`. The proxy will then be used for * push * url previews * phone-home stats * recaptcha validation * CAS auth validation It will *not* be used for: * Application Services * Identity servers * Outbound federation * In worker configurations, connections from workers to masters Fixes #4198. --- changelog.d/6238.feature | 1 + synapse/app/homeserver.py | 2 +- synapse/handlers/ui_auth/checkers.py | 2 +- synapse/http/client.py | 17 +- synapse/http/connectproxyclient.py | 195 ++++++++++++ synapse/http/proxyagent.py | 195 ++++++++++++ synapse/push/httppusher.py | 2 +- synapse/rest/client/v1/login.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 + synapse/server.py | 9 + synapse/server.pyi | 9 + tests/http/__init__.py | 17 ++ .../federation/test_matrix_federation_agent.py | 11 +- tests/http/test_proxyagent.py | 334 +++++++++++++++++++++ tests/push/test_http.py | 2 +- tests/server.py | 24 +- 16 files changed, 812 insertions(+), 12 deletions(-) create mode 100644 changelog.d/6238.feature create mode 100644 synapse/http/connectproxyclient.py create mode 100644 synapse/http/proxyagent.py create mode 100644 tests/http/test_proxyagent.py (limited to 'synapse/server.pyi') diff --git a/changelog.d/6238.feature b/changelog.d/6238.feature new file mode 100644 index 0000000000..d225ac33b6 --- /dev/null +++ b/changelog.d/6238.feature @@ -0,0 +1 @@ +Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8997c1f9e7..8d28076d92 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -565,7 +565,7 @@ def run(hs): "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats) ) try: - yield hs.get_simple_http_client().put_json( + yield hs.get_proxied_http_client().put_json( hs.config.report_stats_endpoint, stats ) except Exception as e: diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 29aa1e5aaf..8363d887a9 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs): super().__init__(hs) self._enabled = bool(hs.config.recaptcha_private_key) - self._http_client = hs.get_simple_http_client() + self._http_client = hs.get_proxied_http_client() self._url = hs.config.recaptcha_siteverify_api self._secret = hs.config.recaptcha_private_key diff --git a/synapse/http/client.py b/synapse/http/client.py index 2df5b383b5..d4c285445e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -45,6 +45,7 @@ from synapse.http import ( cancelled_to_request_timed_out_error, redact_uri, ) +from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util.async_helpers import timeout_deferred @@ -183,7 +184,15 @@ class SimpleHttpClient(object): using HTTP in Matrix """ - def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None): + def __init__( + self, + hs, + treq_args={}, + ip_whitelist=None, + ip_blacklist=None, + http_proxy=None, + https_proxy=None, + ): """ Args: hs (synapse.server.HomeServer) @@ -192,6 +201,8 @@ class SimpleHttpClient(object): we may not request. ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. + http_proxy (bytes): proxy server to use for http connections. host[:port] + https_proxy (bytes): proxy server to use for https connections. host[:port] """ self.hs = hs @@ -236,11 +247,13 @@ class SimpleHttpClient(object): # The default context factory in Twisted 14.0.0 (which we require) is # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' - self.agent = Agent( + self.agent = ProxyAgent( self.reactor, connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, + http_proxy=http_proxy, + https_proxy=https_proxy, ) if self._ip_blacklist: diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py new file mode 100644 index 0000000000..be7b2ceb8e --- /dev/null +++ b/synapse/http/connectproxyclient.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from zope.interface import implementer + +from twisted.internet import defer, protocol +from twisted.internet.error import ConnectError +from twisted.internet.interfaces import IStreamClientEndpoint +from twisted.internet.protocol import connectionDone +from twisted.web import http + +logger = logging.getLogger(__name__) + + +class ProxyConnectError(ConnectError): + pass + + +@implementer(IStreamClientEndpoint) +class HTTPConnectProxyEndpoint(object): + """An Endpoint implementation which will send a CONNECT request to an http proxy + + Wraps an existing HostnameEndpoint for the proxy. + + When we get the connect() request from the connection pool (via the TLS wrapper), + we'll first connect to the proxy endpoint with a ProtocolFactory which will make the + CONNECT request. Once that completes, we invoke the protocolFactory which was passed + in. + + Args: + reactor: the Twisted reactor to use for the connection + proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the + proxy + host (bytes): hostname that we want to CONNECT to + port (int): port that we want to connect to + """ + + def __init__(self, reactor, proxy_endpoint, host, port): + self._reactor = reactor + self._proxy_endpoint = proxy_endpoint + self._host = host + self._port = port + + def __repr__(self): + return "" % (self._proxy_endpoint,) + + def connect(self, protocolFactory): + f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory) + d = self._proxy_endpoint.connect(f) + # once the tcp socket connects successfully, we need to wait for the + # CONNECT to complete. + d.addCallback(lambda conn: f.on_connection) + return d + + +class HTTPProxiedClientFactory(protocol.ClientFactory): + """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect. + + Once the CONNECT completes, invokes the original ClientFactory to build the + HTTP Protocol object and run the rest of the connection. + + Args: + dst_host (bytes): hostname that we want to CONNECT to + dst_port (int): port that we want to connect to + wrapped_factory (protocol.ClientFactory): The original Factory + """ + + def __init__(self, dst_host, dst_port, wrapped_factory): + self.dst_host = dst_host + self.dst_port = dst_port + self.wrapped_factory = wrapped_factory + self.on_connection = defer.Deferred() + + def startedConnecting(self, connector): + return self.wrapped_factory.startedConnecting(connector) + + def buildProtocol(self, addr): + wrapped_protocol = self.wrapped_factory.buildProtocol(addr) + + return HTTPConnectProtocol( + self.dst_host, self.dst_port, wrapped_protocol, self.on_connection + ) + + def clientConnectionFailed(self, connector, reason): + logger.debug("Connection to proxy failed: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionFailed(connector, reason) + + def clientConnectionLost(self, connector, reason): + logger.debug("Connection to proxy lost: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionLost(connector, reason) + + +class HTTPConnectProtocol(protocol.Protocol): + """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect + + Args: + host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal + to put in the CONNECT request + + port (int): The original HTTP(s) port to put in the CONNECT request + + wrapped_protocol (interfaces.IProtocol): the original protocol (probably + HTTPChannel or TLSMemoryBIOProtocol, but could be anything really) + + connected_deferred (Deferred): a Deferred which will be callbacked with + wrapped_protocol when the CONNECT completes + """ + + def __init__(self, host, port, wrapped_protocol, connected_deferred): + self.host = host + self.port = port + self.wrapped_protocol = wrapped_protocol + self.connected_deferred = connected_deferred + self.http_setup_client = HTTPConnectSetupClient(self.host, self.port) + self.http_setup_client.on_connected.addCallback(self.proxyConnected) + + def connectionMade(self): + self.http_setup_client.makeConnection(self.transport) + + def connectionLost(self, reason=connectionDone): + if self.wrapped_protocol.connected: + self.wrapped_protocol.connectionLost(reason) + + self.http_setup_client.connectionLost(reason) + + if not self.connected_deferred.called: + self.connected_deferred.errback(reason) + + def proxyConnected(self, _): + self.wrapped_protocol.makeConnection(self.transport) + + self.connected_deferred.callback(self.wrapped_protocol) + + # Get any pending data from the http buf and forward it to the original protocol + buf = self.http_setup_client.clearLineBuffer() + if buf: + self.wrapped_protocol.dataReceived(buf) + + def dataReceived(self, data): + # if we've set up the HTTP protocol, we can send the data there + if self.wrapped_protocol.connected: + return self.wrapped_protocol.dataReceived(data) + + # otherwise, we must still be setting up the connection: send the data to the + # setup client + return self.http_setup_client.dataReceived(data) + + +class HTTPConnectSetupClient(http.HTTPClient): + """HTTPClient protocol to send a CONNECT message for proxies and read the response. + + Args: + host (bytes): The hostname to send in the CONNECT message + port (int): The port to send in the CONNECT message + """ + + def __init__(self, host, port): + self.host = host + self.port = port + self.on_connected = defer.Deferred() + + def connectionMade(self): + logger.debug("Connected to proxy, sending CONNECT") + self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) + self.endHeaders() + + def handleStatus(self, version, status, message): + logger.debug("Got Status: %s %s %s", status, message, version) + if status != b"200": + raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) + + def handleEndHeaders(self): + logger.debug("End Headers") + self.on_connected.callback(None) + + def handleResponse(self, body): + pass diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py new file mode 100644 index 0000000000..332da02a8d --- /dev/null +++ b/synapse/http/proxyagent.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.python.failure import Failure +from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase +from twisted.web.error import SchemeNotSupported +from twisted.web.iweb import IAgent + +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint + +logger = logging.getLogger(__name__) + +_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") + + +@implementer(IAgent) +class ProxyAgent(_AgentBase): + """An Agent implementation which will use an HTTP proxy if one was requested + + Args: + reactor: twisted reactor to place outgoing + connections. + + contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the + verification parameters of OpenSSL. The default is to use a + `BrowserLikePolicyForHTTPS`, so unless you have special + requirements you can leave this as-is. + + connectTimeout (float): The amount of time that this Agent will wait + for the peer to accept a connection. + + bindAddress (bytes): The local address for client sockets to bind to. + + pool (HTTPConnectionPool|None): connection pool to be used. If None, a + non-persistent pool instance will be created. + """ + + def __init__( + self, + reactor, + contextFactory=BrowserLikePolicyForHTTPS(), + connectTimeout=None, + bindAddress=None, + pool=None, + http_proxy=None, + https_proxy=None, + ): + _AgentBase.__init__(self, reactor, pool) + + self._endpoint_kwargs = {} + if connectTimeout is not None: + self._endpoint_kwargs["timeout"] = connectTimeout + if bindAddress is not None: + self._endpoint_kwargs["bindAddress"] = bindAddress + + self.http_proxy_endpoint = _http_proxy_endpoint( + http_proxy, reactor, **self._endpoint_kwargs + ) + + self.https_proxy_endpoint = _http_proxy_endpoint( + https_proxy, reactor, **self._endpoint_kwargs + ) + + self._policy_for_https = contextFactory + self._reactor = reactor + + def request(self, method, uri, headers=None, bodyProducer=None): + """ + Issue a request to the server indicated by the given uri. + + Supports `http` and `https` schemes. + + An existing connection from the connection pool may be used or a new one may be + created. + + See also: twisted.web.iweb.IAgent.request + + Args: + method (bytes): The request method to use, such as `GET`, `POST`, etc + + uri (bytes): The location of the resource to request. + + headers (Headers|None): Extra headers to send with the request + + bodyProducer (IBodyProducer|None): An object which can generate bytes to + make up the body of this request (for example, the properly encoded + contents of a file for a file upload). Or, None if the request is to + have no body. + + Returns: + Deferred[IResponse]: completes when the header of the response has + been received (regardless of the response status code). + """ + uri = uri.strip() + if not _VALID_URI.match(uri): + raise ValueError("Invalid URI {!r}".format(uri)) + + parsed_uri = URI.fromBytes(uri) + pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) + request_path = parsed_uri.originForm + + if parsed_uri.scheme == b"http" and self.http_proxy_endpoint: + # Cache *all* connections under the same key, since we are only + # connecting to a single destination, the proxy: + pool_key = ("http-proxy", self.http_proxy_endpoint) + endpoint = self.http_proxy_endpoint + request_path = uri + elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: + endpoint = HTTPConnectProxyEndpoint( + self._reactor, + self.https_proxy_endpoint, + parsed_uri.host, + parsed_uri.port, + ) + else: + # not using a proxy + endpoint = HostnameEndpoint( + self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs + ) + + logger.debug("Requesting %s via %s", uri, endpoint) + + if parsed_uri.scheme == b"https": + tls_connection_creator = self._policy_for_https.creatorForNetloc( + parsed_uri.host, parsed_uri.port + ) + endpoint = wrapClientTLS(tls_connection_creator, endpoint) + elif parsed_uri.scheme == b"http": + pass + else: + return defer.fail( + Failure( + SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,)) + ) + ) + + return self._requestWithEndpoint( + pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path + ) + + +def _http_proxy_endpoint(proxy, reactor, **kwargs): + """Parses an http proxy setting and returns an endpoint for the proxy + + Args: + proxy (bytes|None): the proxy setting + reactor: reactor to be used to connect to the proxy + kwargs: other args to be passed to HostnameEndpoint + + Returns: + interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy, + or None + """ + if proxy is None: + return None + + # currently we only support hostname:port. Some apps also support + # protocol://[:port], which allows a way of requiring a TLS connection to the + # proxy. + + host, port = parse_host_port(proxy, default_port=1080) + return HostnameEndpoint(reactor, host, port, **kwargs) + + +def parse_host_port(hostport, default_port=None): + # could have sworn we had one of these somewhere else... + if b":" in hostport: + host, port = hostport.rsplit(b":", 1) + try: + port = int(port) + return host, port + except ValueError: + # the thing after the : wasn't a valid port; presumably this is an + # IPv6 address. + pass + + return hostport, default_port diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 7dde2ad055..e994037be6 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -103,7 +103,7 @@ class HttpPusher(object): if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") self.url = self.data["url"] - self.http_client = hs.get_simple_http_client() + self.http_client = hs.get_proxied_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 00a7dd6d09..24a0ce74f2 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -381,7 +381,7 @@ class CasTicketServlet(RestServlet): self.cas_displayname_attribute = hs.config.cas_displayname_attribute self.cas_required_attributes = hs.config.cas_required_attributes self._sso_auth_handler = SSOAuthHandler(hs) - self._http_client = hs.get_simple_http_client() + self._http_client = hs.get_proxied_http_client() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 5a25b6b3fc..531d923f76 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource): treq_args={"browser_like_redirects": True}, ip_whitelist=hs.config.url_preview_ip_range_whitelist, ip_blacklist=hs.config.url_preview_ip_range_blacklist, + http_proxy=os.getenv("http_proxy"), + https_proxy=os.getenv("HTTPS_PROXY"), ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path diff --git a/synapse/server.py b/synapse/server.py index 0b81af646c..f8aeebcff8 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -23,6 +23,7 @@ # Imports required for the default HomeServer() implementation import abc import logging +import os from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail @@ -168,6 +169,7 @@ class HomeServer(object): "filtering", "http_client_context_factory", "simple_http_client", + "proxied_http_client", "media_repository", "media_repository_resource", "federation_transport_client", @@ -311,6 +313,13 @@ class HomeServer(object): def build_simple_http_client(self): return SimpleHttpClient(self) + def build_proxied_http_client(self): + return SimpleHttpClient( + self, + http_proxy=os.getenv("http_proxy"), + https_proxy=os.getenv("HTTPS_PROXY"), + ) + def build_room_creation_handler(self): return RoomCreationHandler(self) diff --git a/synapse/server.pyi b/synapse/server.pyi index 83d1f11283..b5e0b57095 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -12,6 +12,7 @@ import synapse.handlers.message import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password +import synapse.http.client import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -38,6 +39,14 @@ class HomeServer(object): pass def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass + def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + or support any HTTP_PROXY settings""" + pass + def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + but does support HTTP_PROXY settings""" + pass def get_deactivate_account_handler( self, ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 2d5dba6464..2096ba3c91 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -20,6 +20,23 @@ from zope.interface import implementer from OpenSSL import SSL from OpenSSL.SSL import Connection from twisted.internet.interfaces import IOpenSSLServerConnectionCreator +from twisted.internet.ssl import Certificate, trustRootFromCertificates +from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401 +from twisted.web.iweb import IPolicyForHTTPS # noqa: F401 + + +def get_test_https_policy(): + """Get a test IPolicyForHTTPS which trusts the test CA cert + + Returns: + IPolicyForHTTPS + """ + ca_file = get_test_ca_cert_file() + with open(ca_file) as stream: + content = stream.read() + cert = Certificate.loadPEM(content) + trust_root = trustRootFromCertificates([cert]) + return BrowserLikePolicyForHTTPS(trustRoot=trust_root) def get_test_ca_cert_file(): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 71d7025264..cfcd98ff7d 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase): FakeTransport(client_protocol, self.reactor, server_tls_protocol) ) + # grab a hold of the TLS connection, in case it gets torn down + server_tls_connection = server_tls_protocol._tlsConnection + + # fish the test server back out of the server-side TLS protocol. + http_protocol = server_tls_protocol.wrappedProtocol + # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) # check the SNI - server_name = server_tls_protocol._tlsConnection.get_servername() + server_name = server_tls_connection.get_servername() self.assertEqual( server_name, expected_sni, "Expected SNI %s but got %s" % (expected_sni, server_name), ) - # fish the test server back out of the server-side TLS protocol. - return server_tls_protocol.wrappedProtocol + return http_protocol @defer.inlineCallbacks def _make_get_request(self, uri): diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py new file mode 100644 index 0000000000..22abf76515 --- /dev/null +++ b/tests/http/test_proxyagent.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import treq + +from twisted.internet import interfaces # noqa: F401 +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web.http import HTTPChannel + +from synapse.http.proxyagent import ProxyAgent + +from tests.http import TestServerTLSConnectionFactory, get_test_https_policy +from tests.server import FakeTransport, ThreadedMemoryReactorClock +from tests.unittest import TestCase + +logger = logging.getLogger(__name__) + +HTTPFactory = Factory.forProtocol(HTTPChannel) + + +class MatrixFederationAgentTests(TestCase): + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + + def _make_connection( + self, client_factory, server_factory, ssl=False, expected_sni=None + ): + """Builds a test server, and completes the outgoing client connection + + Args: + client_factory (interfaces.IProtocolFactory): the the factory that the + application is trying to use to make the outbound connection. We will + invoke it to build the client Protocol + + server_factory (interfaces.IProtocolFactory): a factory to build the + server-side protocol + + ssl (bool): If true, we will expect an ssl connection and wrap + server_factory with a TLSMemoryBIOFactory + + expected_sni (bytes|None): the expected SNI value + + Returns: + IProtocol: the server Protocol returned by server_factory + """ + if ssl: + server_factory = _wrap_server_factory_for_tls(server_factory) + + server_protocol = server_factory.buildProtocol(None) + + # now, tell the client protocol factory to build the client protocol, + # and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_protocol, self.reactor, client_protocol) + ) + + # tell the server protocol to send its stuff back to the client, too + server_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_protocol) + ) + + if ssl: + http_protocol = server_protocol.wrappedProtocol + tls_connection = server_protocol._tlsConnection + else: + http_protocol = server_protocol + tls_connection = None + + # give the reactor a pump to get the TLS juices flowing (if needed) + self.reactor.advance(0) + + if expected_sni is not None: + server_name = tls_connection.get_servername() + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + return http_protocol + + def test_http_request(self): + agent = ProxyAgent(self.reactor) + + self.reactor.lookups["test.com"] = "1.2.3.4" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 80) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request(self): + agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy()) + + self.reactor.lookups["test.com"] = "1.2.3.4" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 443) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, + _get_test_protocol_factory(), + ssl=True, + expected_sni=b"test.com", + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_http_request_via_proxy(self): + agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888") + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 8888) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"http://test.com") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request_via_proxy(self): + agent = ProxyAgent( + self.reactor, + contextFactory=get_test_https_policy(), + https_proxy=b"proxy.com", + ) + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 1080) + + # make a test HTTP server, and wire up the client + proxy_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # fish the transports back out so that we can do the old switcheroo + s2c_transport = proxy_server.transport + client_protocol = s2c_transport.other + c2s_transport = client_protocol.transport + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"test.com:443") + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + # this just stops the http Request trying to do a chunked response + # request.setHeader(b"Content-Length", b"0") + request.finish() + + # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel + ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) + ssl_protocol = ssl_factory.buildProtocol(None) + http_server = ssl_protocol.wrappedProtocol + + ssl_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, ssl_protocol) + ) + c2s_transport.other = ssl_protocol + + self.reactor.advance(0) + + server_name = ssl_protocol._tlsConnection.get_servername() + expected_sni = b"test.com" + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + +def _wrap_server_factory_for_tls(factory, sanlist=None): + """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory + + The resultant factory will create a TLS server which presents a certificate + signed by our test CA, valid for the domains in `sanlist` + + Args: + factory (interfaces.IProtocolFactory): protocol factory to wrap + sanlist (iterable[bytes]): list of domains the cert should be valid for + + Returns: + interfaces.IProtocolFactory + """ + if sanlist is None: + sanlist = [b"DNS:test.com"] + + connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) + return TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=factory + ) + + +def _get_test_protocol_factory(): + """Get a protocol Factory which will build an HTTPChannel + + Returns: + interfaces.IProtocolFactory + """ + server_factory = Factory.forProtocol(HTTPChannel) + + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + return server_factory + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 8ce6bb62da..af2327fb66 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase): config = self.default_config() config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, simple_http_client=m) + hs = self.setup_test_homeserver(config=config, proxied_http_client=m) return hs diff --git a/tests/server.py b/tests/server.py index 469efb4edb..f878aeaada 100644 --- a/tests/server.py +++ b/tests/server.py @@ -395,11 +395,24 @@ class FakeTransport(object): self.disconnecting = True if self._protocol: self._protocol.connectionLost(reason) - self.disconnected = True + + # if we still have data to write, delay until that is done + if self.buffer: + logger.info( + "FakeTransport: Delaying disconnect until buffer is flushed" + ) + else: + self.disconnected = True def abortConnection(self): logger.info("FakeTransport: abortConnection()") - self.loseConnection() + + if not self.disconnecting: + self.disconnecting = True + if self._protocol: + self._protocol.connectionLost(None) + + self.disconnected = True def pauseProducing(self): if not self.producer: @@ -430,6 +443,9 @@ class FakeTransport(object): self._reactor.callLater(0.0, _produce) def write(self, byt): + if self.disconnecting: + raise Exception("Writing to disconnecting FakeTransport") + self.buffer = self.buffer + byt # always actually do the write asynchronously. Some protocols (notably the @@ -474,6 +490,10 @@ class FakeTransport(object): if self.buffer and self.autoflush: self._reactor.callLater(0.0, self.flush) + if not self.buffer and self.disconnecting: + logger.info("FakeTransport: Buffer now empty, completing disconnect") + self.disconnected = True + def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol: """ -- cgit 1.4.1 From a8a50f5b5746279379b4511c8ecb2a40b143fe32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 Jan 2020 10:27:19 +0000 Subject: Wake up transaction queue when remote server comes back online (#6706) This will be used to retry outbound transactions to a remote server if we think it might have come back up. --- changelog.d/6706.misc | 1 + docs/tcp_replication.md | 6 +++++- synapse/app/federation_sender.py | 12 +++++++++++- synapse/federation/sender/__init__.py | 18 ++++++++++++++++-- synapse/federation/transport/server.py | 19 ++++++++++++++++++- synapse/notifier.py | 31 ++++++++++++++++++++++++++++--- synapse/replication/tcp/client.py | 3 +++ synapse/replication/tcp/commands.py | 17 +++++++++++++++++ synapse/replication/tcp/protocol.py | 15 +++++++++++++++ synapse/replication/tcp/resource.py | 9 +++++++++ synapse/server.pyi | 12 ++++++++++++ 11 files changed, 135 insertions(+), 8 deletions(-) create mode 100644 changelog.d/6706.misc (limited to 'synapse/server.pyi') diff --git a/changelog.d/6706.misc b/changelog.d/6706.misc new file mode 100644 index 0000000000..1ac11cc04b --- /dev/null +++ b/changelog.d/6706.misc @@ -0,0 +1 @@ +Attempt to retry sending a transaction when we detect a remote server has come back online, rather than waiting for a transaction to be triggered by new data. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index ba9e874d07..a0b1d563ff 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -209,7 +209,7 @@ Where `` may be either: * a numeric stream_id to stream updates since (exclusive) * `NOW` to stream all subsequent updates. -The `` is the name of a replication stream to subscribe +The `` is the name of a replication stream to subscribe to (see [here](../synapse/replication/tcp/streams/_base.py) for a list of streams). It can also be `ALL` to subscribe to all known streams, in which case the `` must be set to `NOW`. @@ -234,6 +234,10 @@ in which case the `` must be set to `NOW`. Used exclusively in tests +### REMOTE_SERVER_UP (S, C) + + Inform other processes that a remote server may have come back online. + See `synapse/replication/tcp/commands.py` for a detailed description and the format of each command. diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index a57cf991ac..38d11fdd0f 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -158,6 +158,13 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): args.update(self.send_handler.stream_positions()) return args + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + + # Let's wake up the transaction queue for the server in case we have + # pending stuff to send to it. + self.send_handler.wake_destination(server) + def start(config_options): try: @@ -205,7 +212,7 @@ class FederationSenderHandler(object): to the federation sender. """ - def __init__(self, hs, replication_client): + def __init__(self, hs: FederationSenderServer, replication_client): self.store = hs.get_datastore() self._is_mine_id = hs.is_mine_id self.federation_sender = hs.get_federation_sender() @@ -226,6 +233,9 @@ class FederationSenderHandler(object): self.store.get_room_max_stream_ordering() ) + def wake_destination(self, server: str): + self.federation_sender.wake_destination(server) + def stream_positions(self): return {"federation": self.federation_position} diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4ebb0e8bc0..36c83c3027 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -21,6 +21,7 @@ from prometheus_client import Counter from twisted.internet import defer +import synapse import synapse.metrics from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager @@ -54,7 +55,7 @@ sent_pdus_destination_dist_total = Counter( class FederationSender(object): - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs self.server_name = hs.hostname @@ -482,7 +483,20 @@ class FederationSender(object): def send_device_messages(self, destination): if destination == self.server_name: - logger.info("Not sending device update to ourselves") + logger.warning("Not sending device update to ourselves") + return + + self._get_per_destination_queue(destination).attempt_new_transaction() + + def wake_destination(self, destination: str): + """Called when we want to retry sending transactions to a remote. + + This is mainly useful if the remote server has been down and we think it + might have come back. + """ + + if destination == self.server_name: + logger.warning("Not waking up ourselves") return self._get_per_destination_queue(destination).attempt_new_transaction() diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index b4cbf23394..d8cf9ed299 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -44,6 +44,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string @@ -101,12 +102,17 @@ class NoAuthenticationError(AuthenticationError): class Authenticator(object): - def __init__(self, hs): + def __init__(self, hs: HomeServer): self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname self.store = hs.get_datastore() self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.notifer = hs.get_notifier() + + self.replication_client = None + if hs.config.worker.worker_app: + self.replication_client = hs.get_tcp_replication() # A method just so we can pass 'self' as the authenticator to the Servlets async def authenticate_request(self, request, content): @@ -166,6 +172,17 @@ class Authenticator(object): try: logger.info("Marking origin %r as up", origin) await self.store.set_destination_retry_timings(origin, None, 0, 0) + + # Inform the relevant places that the remote server is back up. + self.notifer.notify_remote_server_up(origin) + if self.replication_client: + # If we're on a worker we try and inform master about this. The + # replication client doesn't hook into the notifier to avoid + # infinite loops where we send a `REMOTE_SERVER_UP` command to + # master, which then echoes it back to us which in turn pokes + # the notifier. + self.replication_client.send_remote_server_up(origin) + except Exception: logger.exception("Error resetting retry timings on %s", origin) diff --git a/synapse/notifier.py b/synapse/notifier.py index 5f5f765bea..6132727cbd 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,11 +15,13 @@ import logging from collections import namedtuple +from typing import Callable, List from prometheus_client import Counter from twisted.internet import defer +import synapse.server from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state @@ -154,7 +156,7 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.user_to_user_stream = {} self.room_to_user_streams = {} @@ -164,7 +166,12 @@ class Notifier(object): self.store = hs.get_datastore() self.pending_new_room_events = [] - self.replication_callbacks = [] + # Called when there are new things to stream over replication + self.replication_callbacks = [] # type: List[Callable[[], None]] + + # Called when remote servers have come back online after having been + # down. + self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]] self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() @@ -205,7 +212,7 @@ class Notifier(object): "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream) ) - def add_replication_callback(self, cb): + def add_replication_callback(self, cb: Callable[[], None]): """Add a callback that will be called when some new data is available. Callback is not given any arguments. It should *not* return a Deferred - if it needs to do any asynchronous work, a background thread should be started and @@ -213,6 +220,12 @@ class Notifier(object): """ self.replication_callbacks.append(cb) + def add_remote_server_up_callback(self, cb: Callable[[str], None]): + """Add a callback that will be called when synapse detects a server + has been + """ + self.remote_server_up_callbacks.append(cb) + def on_new_room_event( self, event, room_stream_id, max_room_stream_id, extra_users=[] ): @@ -522,3 +535,15 @@ class Notifier(object): """Notify the any replication listeners that there's a new event""" for cb in self.replication_callbacks: cb() + + def notify_remote_server_up(self, server: str): + """Notify any replication that a remote server has come back up + """ + # We call federation_sender directly rather than registering as a + # callback as a) we already have a reference to it and b) it introduces + # circular dependencies. + if self.federation_sender: + self.federation_sender.wake_destination(server) + + for cb in self.remote_server_up_callbacks: + cb(server) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 52a0aefe68..fc06a7b053 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -143,6 +143,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): if d: d.callback(data) + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index cbb36b9acf..451671412d 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -387,6 +387,20 @@ class UserIpCommand(Command): ) +class RemoteServerUpCommand(Command): + """Sent when a worker has detected that a remote server is no longer + "down" and retry timings should be reset. + + If sent from a client the server will relay to all other workers. + + Format:: + + REMOTE_SERVER_UP + """ + + NAME = "REMOTE_SERVER_UP" + + _COMMANDS = ( ServerCommand, RdataCommand, @@ -401,6 +415,7 @@ _COMMANDS = ( RemovePusherCommand, InvalidateCacheCommand, UserIpCommand, + RemoteServerUpCommand, ) # type: Tuple[Type[Command], ...] # Map of command name to command type. @@ -414,6 +429,7 @@ VALID_SERVER_COMMANDS = ( ErrorCommand.NAME, PingCommand.NAME, SyncCommand.NAME, + RemoteServerUpCommand.NAME, ) # The commands the client is allowed to send @@ -427,4 +443,5 @@ VALID_CLIENT_COMMANDS = ( InvalidateCacheCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, + RemoteServerUpCommand.NAME, ) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 5f4bdf84d2..131e5acb09 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -76,6 +76,7 @@ from synapse.replication.tcp.commands import ( PingCommand, PositionCommand, RdataCommand, + RemoteServerUpCommand, ReplicateCommand, ServerCommand, SyncCommand, @@ -460,6 +461,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_INVALIDATE_CACHE(self, cmd): self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + self.streamer.on_remote_server_up(cmd.data) + async def on_USER_IP(self, cmd): self.streamer.on_user_ip( cmd.user_id, @@ -555,6 +559,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): def send_sync(self, data): self.send_command(SyncCommand(data)) + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) + def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) self.streamer.lost_connection(self) @@ -588,6 +595,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): """Called when get a new SYNC command.""" raise NotImplementedError() + @abc.abstractmethod + async def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + raise NotImplementedError() + @abc.abstractmethod def get_streams_to_replicate(self): """Called when a new connection has been established and we need to @@ -707,6 +719,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_SYNC(self, cmd): self.handler.on_sync(cmd.data) + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + self.handler.on_remote_server_up(cmd.data) + def replicate(self, stream_name, token): """Send the subscription request to the server """ diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index b1752e88cd..6ebf944f66 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -120,6 +120,7 @@ class ReplicationStreamer(object): self.federation_sender = hs.get_federation_sender() self.notifier.add_replication_callback(self.on_notifier_poke) + self.notifier.add_remote_server_up_callback(self.send_remote_server_up) # Keeps track of whether we are currently checking for updates self.is_looping = False @@ -288,6 +289,14 @@ class ReplicationStreamer(object): ) await self._server_notices_sender.on_user_ip(user_id) + @measure_func("repl.on_remote_server_up") + def on_remote_server_up(self, server: str): + self.notifier.notify_remote_server_up(server) + + def send_remote_server_up(self, server: str): + for conn in self.connections: + conn.send_remote_server_up(server) + def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. diff --git a/synapse/server.pyi b/synapse/server.pyi index b5e0b57095..0731403047 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,5 @@ +import twisted.internet + import synapse.api.auth import synapse.config.homeserver import synapse.federation.sender @@ -9,10 +11,12 @@ import synapse.handlers.deactivate_account import synapse.handlers.device import synapse.handlers.e2e_keys import synapse.handlers.message +import synapse.handlers.presence import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password import synapse.http.client +import synapse.notifier import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -85,3 +89,11 @@ class HomeServer(object): self, ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass + def get_notifier(self) -> synapse.notifier.Notifier: + pass + def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler: + pass + def get_clock(self) -> synapse.util.Clock: + pass + def get_reactor(self) -> twisted.internet.base.ReactorBase: + pass -- cgit 1.4.1 From c3d4ad8afdbe181707451410100dec4817c2c01a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 Jan 2020 16:42:11 +0000 Subject: Fix sending server up commands from workers (#6811) Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/6811.bugfix | 1 + synapse/federation/transport/client.py | 5 ++++- synapse/federation/transport/server.py | 26 +++++++++++++++----------- synapse/replication/tcp/client.py | 4 ++++ synapse/server.pyi | 12 +++++++++++- tox.ini | 1 + 6 files changed, 36 insertions(+), 13 deletions(-) create mode 100644 changelog.d/6811.bugfix (limited to 'synapse/server.pyi') diff --git a/changelog.d/6811.bugfix b/changelog.d/6811.bugfix new file mode 100644 index 0000000000..361f2fc2e8 --- /dev/null +++ b/changelog.d/6811.bugfix @@ -0,0 +1 @@ +Fix waking up other workers when remote server is detected to have come back online. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 198257414b..dc563538de 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import Any, Dict from six.moves import urllib @@ -352,7 +353,9 @@ class TransportLayerClient(object): else: path = _create_v1_path("/publicRooms") - args = {"include_all_networks": "true" if include_all_networks else "false"} + args = { + "include_all_networks": "true" if include_all_networks else "false" + } # type: Dict[str, Any] if third_party_instance_id: args["third_party_instance_id"] = (third_party_instance_id,) if limit: diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index d8cf9ed299..125eadd796 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -18,6 +18,7 @@ import functools import logging import re +from typing import Optional, Tuple, Type from twisted.internet.defer import maybeDeferred @@ -267,6 +268,8 @@ class BaseFederationServlet(object): returned. """ + PATH = "" # Overridden in subclasses, the regex to match against the path. + REQUIRE_AUTH = True PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version @@ -347,9 +350,6 @@ class BaseFederationServlet(object): return response - # Extra logic that functools.wraps() doesn't finish - new_func.__self__ = func.__self__ - return new_func def register(self, server): @@ -824,7 +824,7 @@ class PublicRoomList(BaseFederationServlet): if not self.allow_access: raise FederationDeniedError(origin) - limit = int(content.get("limit", 100)) + limit = int(content.get("limit", 100)) # type: Optional[int] since_token = content.get("since", None) search_filter = content.get("filter", None) @@ -971,7 +971,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - result = await self.groups_handler.update_room_in_group( + result = await self.handler.update_room_in_group( group_id, requester_user_id, room_id, config_key, content ) @@ -1422,11 +1422,13 @@ FEDERATION_SERVLET_CLASSES = ( On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, -) +) # type: Tuple[Type[BaseFederationServlet], ...] -OPENID_SERVLET_CLASSES = (OpenIdUserInfo,) +OPENID_SERVLET_CLASSES = ( + OpenIdUserInfo, +) # type: Tuple[Type[BaseFederationServlet], ...] -ROOM_LIST_CLASSES = (PublicRoomList,) +ROOM_LIST_CLASSES = (PublicRoomList,) # type: Tuple[Type[PublicRoomList], ...] GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsProfileServlet, @@ -1447,17 +1449,19 @@ GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsAddRoomsServlet, FederationGroupsAddRoomsConfigServlet, FederationGroupsSettingJoinPolicyServlet, -) +) # type: Tuple[Type[BaseFederationServlet], ...] GROUP_LOCAL_SERVLET_CLASSES = ( FederationGroupsLocalInviteServlet, FederationGroupsRemoveLocalUserServlet, FederationGroupsBulkPublicisedServlet, -) +) # type: Tuple[Type[BaseFederationServlet], ...] -GROUP_ATTESTATION_SERVLET_CLASSES = (FederationGroupsRenewAttestaionServlet,) +GROUP_ATTESTATION_SERVLET_CLASSES = ( + FederationGroupsRenewAttestaionServlet, +) # type: Tuple[Type[BaseFederationServlet], ...] DEFAULT_SERVLET_GROUPS = ( "federation", diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index fc06a7b053..02ab5b66ea 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -31,6 +31,7 @@ from .commands import ( Command, FederationAckCommand, InvalidateCacheCommand, + RemoteServerUpCommand, RemovePusherCommand, UserIpCommand, UserSyncCommand, @@ -210,6 +211,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) self.send_command(cmd) + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) + def await_sync(self, data): """Returns a deferred that is resolved when we receive a SYNC command with given data. diff --git a/synapse/server.pyi b/synapse/server.pyi index 0731403047..90347ac23e 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -2,8 +2,8 @@ import twisted.internet import synapse.api.auth import synapse.config.homeserver +import synapse.crypto.keyring import synapse.federation.sender -import synapse.federation.transaction_queue import synapse.federation.transport.client import synapse.handlers import synapse.handlers.auth @@ -17,6 +17,7 @@ import synapse.handlers.room_member import synapse.handlers.set_password import synapse.http.client import synapse.notifier +import synapse.replication.tcp.client import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -27,6 +28,9 @@ class HomeServer(object): @property def config(self) -> synapse.config.homeserver.HomeServerConfig: pass + @property + def hostname(self) -> str: + pass def get_auth(self) -> synapse.api.auth.Auth: pass def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: @@ -97,3 +101,9 @@ class HomeServer(object): pass def get_reactor(self) -> twisted.internet.base.ReactorBase: pass + def get_keyring(self) -> synapse.crypto.keyring.Keyring: + pass + def get_tcp_replication( + self, + ) -> synapse.replication.tcp.client.ReplicationClientHandler: + pass diff --git a/tox.ini b/tox.ini index 1d946a02ba..88ef12bebd 100644 --- a/tox.ini +++ b/tox.ini @@ -179,6 +179,7 @@ extras = all commands = mypy \ synapse/api \ synapse/config/ \ + synapse/federation/transport \ synapse/handlers/ui_auth \ synapse/logging/ \ synapse/module_api \ -- cgit 1.4.1 From b08b0a22d505b1555f511e3f38935a62930ea25d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Feb 2020 13:56:38 +0000 Subject: Add typing to synapse.federation.sender (#6871) --- changelog.d/6871.misc | 1 + synapse/federation/federation_server.py | 7 +- synapse/federation/sender/__init__.py | 99 +++++++++++----------- synapse/federation/sender/per_destination_queue.py | 88 +++++++++---------- synapse/federation/sender/transaction_manager.py | 16 ++-- synapse/federation/units.py | 23 ++++- synapse/server.pyi | 2 + tests/handlers/test_typing.py | 8 +- tox.ini | 1 + 9 files changed, 138 insertions(+), 107 deletions(-) create mode 100644 changelog.d/6871.misc (limited to 'synapse/server.pyi') diff --git a/changelog.d/6871.misc b/changelog.d/6871.misc new file mode 100644 index 0000000000..5161af9983 --- /dev/null +++ b/changelog.d/6871.misc @@ -0,0 +1 @@ +Add typing to `synapse.federation.sender` and port to async/await. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2489832a11..a6c966a393 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -294,7 +294,12 @@ class FederationServer(FederationBase): async def _process_edu(edu_dict): received_edus_counter.inc() - edu = Edu(**edu_dict) + edu = Edu( + origin=origin, + destination=self.server_name, + edu_type=edu_dict["edu_type"], + content=edu_dict["content"], + ) await self.registry.on_edu(edu.edu_type, origin, edu.content) await concurrently_execute( diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 36c83c3027..233cb33daf 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict, Hashable, Iterable, List, Optional, Set from six import itervalues @@ -23,6 +24,7 @@ from twisted.internet import defer import synapse import synapse.metrics +from synapse.events import EventBase from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.units import Edu @@ -39,6 +41,8 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.presence import UserPresenceState +from synapse.types import ReadReceipt from synapse.util.metrics import Measure, measure_func logger = logging.getLogger(__name__) @@ -68,7 +72,7 @@ class FederationSender(object): self._transaction_manager = TransactionManager(hs) # map from destination to PerDestinationQueue - self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] + self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] LaterGauge( "synapse_federation_transaction_queue_pending_destinations", @@ -84,7 +88,7 @@ class FederationSender(object): # Map of user_id -> UserPresenceState for all the pending presence # to be sent out by user_id. Entries here get processed and put in # pending_presence_by_dest - self.pending_presence = {} + self.pending_presence = {} # type: Dict[str, UserPresenceState] LaterGauge( "synapse_federation_transaction_queue_pending_pdus", @@ -116,20 +120,17 @@ class FederationSender(object): # and that there is a pending call to _flush_rrs_for_room in the system. self._queues_awaiting_rr_flush_by_room = ( {} - ) # type: dict[str, set[PerDestinationQueue]] + ) # type: Dict[str, Set[PerDestinationQueue]] self._rr_txn_interval_per_room_ms = ( - 1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second + 1000.0 / hs.config.federation_rr_transactions_per_room_per_second ) - def _get_per_destination_queue(self, destination): + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination Args: - destination (str): server_name of remote server - - Returns: - PerDestinationQueue + destination: server_name of remote server """ queue = self._per_destination_queues.get(destination) if not queue: @@ -137,7 +138,7 @@ class FederationSender(object): self._per_destination_queues[destination] = queue return queue - def notify_new_events(self, current_id): + def notify_new_events(self, current_id: int) -> None: """This gets called when we have some new events we might want to send out to other servers. """ @@ -151,13 +152,12 @@ class FederationSender(object): "process_event_queue_for_federation", self._process_event_queue_loop ) - @defer.inlineCallbacks - def _process_event_queue_loop(self): + async def _process_event_queue_loop(self) -> None: try: self._is_processing = True while True: - last_token = yield self.store.get_federation_out_pos("events") - next_token, events = yield self.store.get_all_new_events_stream( + last_token = await self.store.get_federation_out_pos("events") + next_token, events = await self.store.get_all_new_events_stream( last_token, self._last_poked_id, limit=100 ) @@ -166,8 +166,7 @@ class FederationSender(object): if not events and next_token >= self._last_poked_id: break - @defer.inlineCallbacks - def handle_event(event): + async def handle_event(event: EventBase) -> None: # Only send events for this server. send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() is_mine = self.is_mine_id(event.sender) @@ -184,7 +183,7 @@ class FederationSender(object): # Otherwise if the last member on a server in a room is # banned then it won't receive the event because it won't # be in the room after the ban. - destinations = yield self.state.get_hosts_in_room_at_events( + destinations = await self.state.get_hosts_in_room_at_events( event.room_id, event_ids=event.prev_event_ids() ) except Exception: @@ -206,17 +205,16 @@ class FederationSender(object): self._send_pdu(event, destinations) - @defer.inlineCallbacks - def handle_room_events(events): + async def handle_room_events(events: Iterable[EventBase]) -> None: with Measure(self.clock, "handle_room_events"): for event in events: - yield handle_event(event) + await handle_event(event) - events_by_room = {} + events_by_room = {} # type: Dict[str, List[EventBase]] for event in events: events_by_room.setdefault(event.room_id, []).append(event) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(handle_room_events, evs) @@ -226,11 +224,11 @@ class FederationSender(object): ) ) - yield self.store.update_federation_out_pos("events", next_token) + await self.store.update_federation_out_pos("events", next_token) if events: now = self.clock.time_msec() - ts = yield self.store.get_received_ts(events[-1].event_id) + ts = await self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_lag.labels( "federation_sender" @@ -254,7 +252,7 @@ class FederationSender(object): finally: self._is_processing = False - def _send_pdu(self, pdu, destinations): + def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: # We loop through all destinations to see whether we already have # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. @@ -276,11 +274,11 @@ class FederationSender(object): self._get_per_destination_queue(destination).send_pdu(pdu, order) @defer.inlineCallbacks - def send_read_receipt(self, receipt): + def send_read_receipt(self, receipt: ReadReceipt): """Send a RR to any other servers in the room Args: - receipt (synapse.types.ReadReceipt): receipt to be sent + receipt: receipt to be sent """ # Some background on the rate-limiting going on here. @@ -343,7 +341,7 @@ class FederationSender(object): else: queue.flush_read_receipts_for_room(room_id) - def _schedule_rr_flush_for_room(self, room_id, n_domains): + def _schedule_rr_flush_for_room(self, room_id: str, n_domains: int) -> None: # that is going to cause approximately len(domains) transactions, so now back # off for that multiplied by RR_TXN_INTERVAL_PER_ROOM backoff_ms = self._rr_txn_interval_per_room_ms * n_domains @@ -352,7 +350,7 @@ class FederationSender(object): self.clock.call_later(backoff_ms, self._flush_rrs_for_room, room_id) self._queues_awaiting_rr_flush_by_room[room_id] = set() - def _flush_rrs_for_room(self, room_id): + def _flush_rrs_for_room(self, room_id: str) -> None: queues = self._queues_awaiting_rr_flush_by_room.pop(room_id) logger.debug("Flushing RRs in %s to %s", room_id, queues) @@ -368,14 +366,11 @@ class FederationSender(object): @preserve_fn # the caller should not yield on this @defer.inlineCallbacks - def send_presence(self, states): + def send_presence(self, states: List[UserPresenceState]): """Send the new presence states to the appropriate destinations. This actually queues up the presence states ready for sending and triggers a background task to process them and send out the transactions. - - Args: - states (list(UserPresenceState)) """ if not self.hs.config.use_presence: # No-op if presence is disabled. @@ -412,11 +407,10 @@ class FederationSender(object): finally: self._processing_pending_presence = False - def send_presence_to_destinations(self, states, destinations): + def send_presence_to_destinations( + self, states: List[UserPresenceState], destinations: List[str] + ) -> None: """Send the given presence states to the given destinations. - - Args: - states (list[UserPresenceState]) destinations (list[str]) """ @@ -431,12 +425,9 @@ class FederationSender(object): @measure_func("txnqueue._process_presence") @defer.inlineCallbacks - def _process_presence_inner(self, states): + def _process_presence_inner(self, states: List[UserPresenceState]): """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination - - Args: - states (list(UserPresenceState)) """ hosts_and_states = yield get_interested_remotes(self.store, states, self.state) @@ -446,14 +437,20 @@ class FederationSender(object): continue self._get_per_destination_queue(destination).send_presence(states) - def build_and_send_edu(self, destination, edu_type, content, key=None): + def build_and_send_edu( + self, + destination: str, + edu_type: str, + content: dict, + key: Optional[Hashable] = None, + ): """Construct an Edu object, and queue it for sending Args: - destination (str): name of server to send to - edu_type (str): type of EDU to send - content (dict): content of EDU - key (Any|None): clobbering key for this edu + destination: name of server to send to + edu_type: type of EDU to send + content: content of EDU + key: clobbering key for this edu """ if destination == self.server_name: logger.info("Not sending EDU to ourselves") @@ -468,12 +465,12 @@ class FederationSender(object): self.send_edu(edu, key) - def send_edu(self, edu, key): + def send_edu(self, edu: Edu, key: Optional[Hashable]): """Queue an EDU for sending Args: - edu (Edu): edu to send - key (Any|None): clobbering key for this edu + edu: edu to send + key: clobbering key for this edu """ queue = self._get_per_destination_queue(edu.destination) if key: @@ -481,7 +478,7 @@ class FederationSender(object): else: queue.send_edu(edu) - def send_device_messages(self, destination): + def send_device_messages(self, destination: str): if destination == self.server_name: logger.warning("Not sending device update to ourselves") return @@ -501,5 +498,5 @@ class FederationSender(object): self._get_per_destination_queue(destination).attempt_new_transaction() - def get_current_token(self): + def get_current_token(self) -> int: return 0 diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 5012aaea35..e13cd20ffa 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -15,11 +15,11 @@ # limitations under the License. import datetime import logging +from typing import Dict, Hashable, Iterable, List, Tuple from prometheus_client import Counter -from twisted.internet import defer - +import synapse.server from synapse.api.errors import ( FederationDeniedError, HttpResponseException, @@ -31,7 +31,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState -from synapse.types import StateMap +from synapse.types import ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter # This is defined in the Matrix spec and enforced by the receiver. @@ -56,13 +56,18 @@ class PerDestinationQueue(object): Manages the per-destination transmission queues. Args: - hs (synapse.HomeServer): - transaction_sender (TransactionManager): - destination (str): the server_name of the destination that we are managing + hs + transaction_sender + destination: the server_name of the destination that we are managing transmission for. """ - def __init__(self, hs, transaction_manager, destination): + def __init__( + self, + hs: "synapse.server.HomeServer", + transaction_manager: "synapse.federation.sender.TransactionManager", + destination: str, + ): self._server_name = hs.hostname self._clock = hs.get_clock() self._store = hs.get_datastore() @@ -72,20 +77,20 @@ class PerDestinationQueue(object): self.transmission_loop_running = False # a list of tuples of (pending pdu, order) - self._pending_pdus = [] # type: list[tuple[EventBase, int]] - self._pending_edus = [] # type: list[Edu] + self._pending_pdus = [] # type: List[Tuple[EventBase, int]] + self._pending_edus = [] # type: List[Edu] # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # based on their key (e.g. typing events by room_id) # Map of (edu_type, key) -> Edu - self._pending_edus_keyed = {} # type: StateMap[Edu] + self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu] # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination - self._pending_presence = {} # type: dict[str, UserPresenceState] + self._pending_presence = {} # type: Dict[str, UserPresenceState] # room_id -> receipt_type -> user_id -> receipt_dict - self._pending_rrs = {} + self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]] self._rrs_pending_flush = False # stream_id of last successfully sent to-device message. @@ -95,50 +100,50 @@ class PerDestinationQueue(object): # stream_id of last successfully sent device list update. self._last_device_list_stream_id = 0 - def __str__(self): + def __str__(self) -> str: return "PerDestinationQueue[%s]" % self._destination - def pending_pdu_count(self): + def pending_pdu_count(self) -> int: return len(self._pending_pdus) - def pending_edu_count(self): + def pending_edu_count(self) -> int: return ( len(self._pending_edus) + len(self._pending_presence) + len(self._pending_edus_keyed) ) - def send_pdu(self, pdu, order): + def send_pdu(self, pdu: EventBase, order: int) -> None: """Add a PDU to the queue, and start the transmission loop if neccessary Args: - pdu (EventBase): pdu to send - order (int): + pdu: pdu to send + order """ self._pending_pdus.append((pdu, order)) self.attempt_new_transaction() - def send_presence(self, states): + def send_presence(self, states: Iterable[UserPresenceState]) -> None: """Add presence updates to the queue. Start the transmission loop if neccessary. Args: - states (iterable[UserPresenceState]): presence to send + states: presence to send """ self._pending_presence.update({state.user_id: state for state in states}) self.attempt_new_transaction() - def queue_read_receipt(self, receipt): + def queue_read_receipt(self, receipt: ReadReceipt) -> None: """Add a RR to the list to be sent. Doesn't start the transmission loop yet (see flush_read_receipts_for_room) Args: - receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued + receipt: receipt to be queued """ self._pending_rrs.setdefault(receipt.room_id, {}).setdefault( receipt.receipt_type, {} )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data} - def flush_read_receipts_for_room(self, room_id): + def flush_read_receipts_for_room(self, room_id: str) -> None: # if we don't have any read-receipts for this room, it may be that we've already # sent them out, so we don't need to flush. if room_id not in self._pending_rrs: @@ -146,15 +151,15 @@ class PerDestinationQueue(object): self._rrs_pending_flush = True self.attempt_new_transaction() - def send_keyed_edu(self, edu, key): + def send_keyed_edu(self, edu: Edu, key: Hashable) -> None: self._pending_edus_keyed[(edu.edu_type, key)] = edu self.attempt_new_transaction() - def send_edu(self, edu): + def send_edu(self, edu) -> None: self._pending_edus.append(edu) self.attempt_new_transaction() - def attempt_new_transaction(self): + def attempt_new_transaction(self) -> None: """Try to start a new transaction to this destination If there is already a transaction in progress to this destination, @@ -177,23 +182,22 @@ class PerDestinationQueue(object): self._transaction_transmission_loop, ) - @defer.inlineCallbacks - def _transaction_transmission_loop(self): - pending_pdus = [] + async def _transaction_transmission_loop(self) -> None: + pending_pdus = [] # type: List[Tuple[EventBase, int]] try: self.transmission_loop_running = True # This will throw if we wouldn't retry. We do this here so we fail # quickly, but we will later check this again in the http client, # hence why we throw the result away. - yield get_retry_limiter(self._destination, self._clock, self._store) + await get_retry_limiter(self._destination, self._clock, self._store) pending_pdus = [] while True: # We have to keep 2 free slots for presence and rr_edus limit = MAX_EDUS_PER_TRANSACTION - 2 - device_update_edus, dev_list_id = yield self._get_device_update_edus( + device_update_edus, dev_list_id = await self._get_device_update_edus( limit ) @@ -202,7 +206,7 @@ class PerDestinationQueue(object): ( to_device_edus, device_stream_id, - ) = yield self._get_to_device_message_edus(limit) + ) = await self._get_to_device_message_edus(limit) pending_edus = device_update_edus + to_device_edus @@ -269,7 +273,7 @@ class PerDestinationQueue(object): # END CRITICAL SECTION - success = yield self._transaction_manager.send_new_transaction( + success = await self._transaction_manager.send_new_transaction( self._destination, pending_pdus, pending_edus ) if success: @@ -280,7 +284,7 @@ class PerDestinationQueue(object): # Remove the acknowledged device messages from the database # Only bother if we actually sent some device messages if to_device_edus: - yield self._store.delete_device_msgs_for_remote( + await self._store.delete_device_msgs_for_remote( self._destination, device_stream_id ) @@ -289,7 +293,7 @@ class PerDestinationQueue(object): logger.info( "Marking as sent %r %r", self._destination, dev_list_id ) - yield self._store.mark_as_sent_devices_by_remote( + await self._store.mark_as_sent_devices_by_remote( self._destination, dev_list_id ) @@ -334,7 +338,7 @@ class PerDestinationQueue(object): # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False - def _get_rr_edus(self, force_flush): + def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: if not self._pending_rrs: return if not force_flush and not self._rrs_pending_flush: @@ -351,17 +355,16 @@ class PerDestinationQueue(object): self._rrs_pending_flush = False yield edu - def _pop_pending_edus(self, limit): + def _pop_pending_edus(self, limit: int) -> List[Edu]: pending_edus = self._pending_edus pending_edus, self._pending_edus = pending_edus[:limit], pending_edus[limit:] return pending_edus - @defer.inlineCallbacks - def _get_device_update_edus(self, limit): + async def _get_device_update_edus(self, limit: int) -> Tuple[List[Edu], int]: last_device_list = self._last_device_list_stream_id # Retrieve list of new device updates to send to the destination - now_stream_id, results = yield self._store.get_device_updates_by_remote( + now_stream_id, results = await self._store.get_device_updates_by_remote( self._destination, last_device_list, limit=limit ) edus = [ @@ -378,11 +381,10 @@ class PerDestinationQueue(object): return (edus, now_stream_id) - @defer.inlineCallbacks - def _get_to_device_message_edus(self, limit): + async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]: last_device_stream_id = self._last_device_stream_id to_device_stream_id = self._store.get_to_device_stream_token() - contents, stream_id = yield self._store.get_new_device_msgs_for_remote( + contents, stream_id = await self._store.get_new_device_msgs_for_remote( self._destination, last_device_stream_id, to_device_stream_id, limit ) edus = [ diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 5fed626d5b..3c2a02a3b3 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -13,14 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List from canonicaljson import json -from twisted.internet import defer - +import synapse.server from synapse.api.errors import HttpResponseException +from synapse.events import EventBase from synapse.federation.persistence import TransactionActions -from synapse.federation.units import Transaction +from synapse.federation.units import Edu, Transaction from synapse.logging.opentracing import ( extract_text_map, set_tag, @@ -39,7 +40,7 @@ class TransactionManager(object): shared between PerDestinationQueue objects """ - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self._server_name = hs.hostname self.clock = hs.get_clock() # nb must be called this for @measure_func self._store = hs.get_datastore() @@ -50,8 +51,9 @@ class TransactionManager(object): self._next_txn_id = int(self.clock.time_msec()) @measure_func("_send_new_transaction") - @defer.inlineCallbacks - def send_new_transaction(self, destination, pending_pdus, pending_edus): + async def send_new_transaction( + self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu] + ): # Make a transaction-sending opentracing span. This span follows on from # all the edus in that transaction. This needs to be done since there is @@ -127,7 +129,7 @@ class TransactionManager(object): return data try: - response = yield self._transport_layer.send_transaction( + response = await self._transport_layer.send_transaction( transaction, json_data_cb ) code = 200 diff --git a/synapse/federation/units.py b/synapse/federation/units.py index b4d743cde7..6b32e0dcbf 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -19,11 +19,15 @@ server protocol. import logging +import attr + +from synapse.types import JsonDict from synapse.util.jsonobject import JsonEncodedObject logger = logging.getLogger(__name__) +@attr.s(slots=True) class Edu(JsonEncodedObject): """ An Edu represents a piece of data sent from one homeserver to another. @@ -32,11 +36,24 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - valid_keys = ["origin", "destination", "edu_type", "content"] + edu_type = attr.ib(type=str) + content = attr.ib(type=dict) + origin = attr.ib(type=str) + destination = attr.ib(type=str) - required_keys = ["edu_type"] + def get_dict(self) -> JsonDict: + return { + "edu_type": self.edu_type, + "content": self.content, + } - internal_keys = ["origin", "destination"] + def get_internal_dict(self) -> JsonDict: + return { + "edu_type": self.edu_type, + "content": self.content, + "origin": self.origin, + "destination": self.destination, + } def get_context(self): return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") diff --git a/synapse/server.pyi b/synapse/server.pyi index 90347ac23e..40eabfe5d9 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -107,3 +107,5 @@ class HomeServer(object): self, ) -> synapse.replication.tcp.client.ReplicationClientHandler: pass + def is_mine_id(self, domain_id: str) -> bool: + pass diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 68b9847bd2..2767b0497a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -111,7 +111,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_device_updates_by_remote.return_value = (0, []) + self.datastore.get_device_updates_by_remote.return_value = defer.succeed( + (0, []) + ) def get_received_txn_response(*args): return defer.succeed(None) @@ -144,7 +146,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) + self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed( + ([], 0) + ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( None diff --git a/tox.ini b/tox.ini index ef22368cf1..f8229eba88 100644 --- a/tox.ini +++ b/tox.ini @@ -179,6 +179,7 @@ extras = all commands = mypy \ synapse/api \ synapse/config/ \ + synapse/federation/sender \ synapse/federation/transport \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ -- cgit 1.4.1 From 1f773eec912e4908ab60f7823f5c0a024261af4d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 Feb 2020 15:33:26 +0000 Subject: Port PresenceHandler to async/await (#6991) --- changelog.d/6991.misc | 1 + synapse/handlers/message.py | 5 +- synapse/handlers/presence.py | 192 ++++++++++++++++-------------------- synapse/replication/tcp/resource.py | 6 +- synapse/server.pyi | 5 + tests/handlers/test_presence.py | 18 ++-- tox.ini | 1 + 7 files changed, 113 insertions(+), 115 deletions(-) create mode 100644 changelog.d/6991.misc (limited to 'synapse/server.pyi') diff --git a/changelog.d/6991.misc b/changelog.d/6991.misc new file mode 100644 index 0000000000..5130f4e8af --- /dev/null +++ b/changelog.d/6991.misc @@ -0,0 +1 @@ +Port `synapse.handlers.presence` to async/await. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6be280952..a0103addd3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1016,11 +1016,10 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - @defer.inlineCallbacks - def _bump_active_time(self, user): + async def _bump_active_time(self, user): try: presence = self.hs.get_presence_handler() - yield presence.bump_presence_active_time(user) + await presence.bump_presence_active_time(user) except Exception: logger.exception("Error bumping presence active time") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 0d6cf2b008..5526015ddb 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -24,11 +24,12 @@ The methods that define policy are: import logging from contextlib import contextmanager -from typing import Dict, Set +from typing import Dict, List, Set from six import iteritems, itervalues from prometheus_client import Counter +from typing_extensions import ContextManager from twisted.internet import defer @@ -42,10 +43,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer +MYPY = False +if MYPY: + import synapse.server + logger = logging.getLogger(__name__) @@ -97,7 +102,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class PresenceHandler(object): def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs - self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.server_name = hs.hostname self.clock = hs.get_clock() @@ -150,7 +154,7 @@ class PresenceHandler(object): # Set of users who have presence in the `user_to_current_state` that # have not yet been persisted - self.unpersisted_users_changes = set() + self.unpersisted_users_changes = set() # type: Set[str] hs.get_reactor().addSystemEventTrigger( "before", @@ -160,12 +164,11 @@ class PresenceHandler(object): self._on_shutdown, ) - self.serial_to_user = {} self._next_serial = 1 # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs = {} + self.user_to_num_current_syncs = {} # type: Dict[str, int] # Keeps track of the number of *ongoing* syncs on other processes. # While any sync is ongoing on another process the user will never @@ -213,8 +216,7 @@ class PresenceHandler(object): self._event_pos = self.store.get_current_events_token() self._event_processing = False - @defer.inlineCallbacks - def _on_shutdown(self): + async def _on_shutdown(self): """Gets called when shutting down. This lets us persist any updates that we haven't yet persisted, e.g. updates that only changes some internal timers. This allows changes to persist across startup without having to @@ -235,7 +237,7 @@ class PresenceHandler(object): if self.unpersisted_users_changes: - yield self.store.update_presence( + await self.store.update_presence( [ self.user_to_current_state[user_id] for user_id in self.unpersisted_users_changes @@ -243,8 +245,7 @@ class PresenceHandler(object): ) logger.info("Finished _on_shutdown") - @defer.inlineCallbacks - def _persist_unpersisted_changes(self): + async def _persist_unpersisted_changes(self): """We periodically persist the unpersisted changes, as otherwise they may stack up and slow down shutdown times. """ @@ -253,12 +254,11 @@ class PresenceHandler(object): if unpersisted: logger.info("Persisting %d unpersisted presence updates", len(unpersisted)) - yield self.store.update_presence( + await self.store.update_presence( [self.user_to_current_state[user_id] for user_id in unpersisted] ) - @defer.inlineCallbacks - def _update_states(self, new_states): + async def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state should be sent to clients/servers. @@ -267,7 +267,7 @@ class PresenceHandler(object): with Measure(self.clock, "presence_update_states"): - # NOTE: We purposefully don't yield between now and when we've + # NOTE: We purposefully don't await between now and when we've # calculated what we want to do with the new states, to avoid races. to_notify = {} # Changes we want to notify everyone about @@ -311,7 +311,7 @@ class PresenceHandler(object): if to_notify: notified_presence_counter.inc(len(to_notify)) - yield self._persist_and_notify(list(to_notify.values())) + await self._persist_and_notify(list(to_notify.values())) self.unpersisted_users_changes |= {s.user_id for s in new_states} self.unpersisted_users_changes -= set(to_notify.keys()) @@ -326,7 +326,7 @@ class PresenceHandler(object): self._push_to_remotes(to_federation_ping.values()) - def _handle_timeouts(self): + async def _handle_timeouts(self): """Checks the presence of users that have timed out and updates as appropriate. """ @@ -368,10 +368,9 @@ class PresenceHandler(object): now=now, ) - return self._update_states(changes) + return await self._update_states(changes) - @defer.inlineCallbacks - def bump_presence_active_time(self, user): + async def bump_presence_active_time(self, user): """We've seen the user do something that indicates they're interacting with the app. """ @@ -383,16 +382,17 @@ class PresenceHandler(object): bump_active_time_counter.inc() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def user_syncing(self, user_id, affect_presence=True): + async def user_syncing( + self, user_id: str, affect_presence: bool = True + ) -> ContextManager[None]: """Returns a context manager that should surround any stream requests from the user. @@ -415,11 +415,11 @@ class PresenceHandler(object): curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) if prev_state.state == PresenceState.OFFLINE: # If they're currently offline then bring them online, otherwise # just update the last sync times. - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( state=PresenceState.ONLINE, @@ -429,7 +429,7 @@ class PresenceHandler(object): ] ) else: - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -437,13 +437,12 @@ class PresenceHandler(object): ] ) - @defer.inlineCallbacks - def _end(): + async def _end(): try: self.user_to_num_current_syncs[user_id] -= 1 - prev_state = yield self.current_state_for_user(user_id) - yield self._update_states( + prev_state = await self.current_state_for_user(user_id) + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -480,8 +479,7 @@ class PresenceHandler(object): else: return set() - @defer.inlineCallbacks - def update_external_syncs_row( + async def update_external_syncs_row( self, process_id, user_id, is_syncing, sync_time_msec ): """Update the syncing users for an external process as a delta. @@ -494,8 +492,8 @@ class PresenceHandler(object): is_syncing (bool): Whether or not the user is now syncing sync_time_msec(int): Time in ms when the user was last syncing """ - with (yield self.external_sync_linearizer.queue(process_id)): - prev_state = yield self.current_state_for_user(user_id) + with (await self.external_sync_linearizer.queue(process_id)): + prev_state = await self.current_state_for_user(user_id) process_presence = self.external_process_to_current_syncs.setdefault( process_id, set() @@ -525,25 +523,24 @@ class PresenceHandler(object): process_presence.discard(user_id) if updates: - yield self._update_states(updates) + await self._update_states(updates) self.external_process_last_updated_ms[process_id] = self.clock.time_msec() - @defer.inlineCallbacks - def update_external_syncs_clear(self, process_id): + async def update_external_syncs_clear(self, process_id): """Marks all users that had been marked as syncing by a given process as offline. Used when the process has stopped/disappeared. """ - with (yield self.external_sync_linearizer.queue(process_id)): + with (await self.external_sync_linearizer.queue(process_id)): process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = yield self.current_state_for_users(process_presence) + prev_states = await self.current_state_for_users(process_presence) time_now_ms = self.clock.time_msec() - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) for prev_state in itervalues(prev_states) @@ -551,15 +548,13 @@ class PresenceHandler(object): ) self.external_process_last_updated_ms.pop(process_id, None) - @defer.inlineCallbacks - def current_state_for_user(self, user_id): + async def current_state_for_user(self, user_id): """Get the current presence state for a user. """ - res = yield self.current_state_for_users([user_id]) + res = await self.current_state_for_users([user_id]) return res[user_id] - @defer.inlineCallbacks - def current_state_for_users(self, user_ids): + async def current_state_for_users(self, user_ids): """Get the current presence state for multiple users. Returns: @@ -574,7 +569,7 @@ class PresenceHandler(object): if missing: # There are things not in our in memory cache. Lets pull them out of # the database. - res = yield self.store.get_presence_for_users(missing) + res = await self.store.get_presence_for_users(missing) states.update(res) missing = [user_id for user_id, state in iteritems(states) if not state] @@ -587,14 +582,13 @@ class PresenceHandler(object): return states - @defer.inlineCallbacks - def _persist_and_notify(self, states): + async def _persist_and_notify(self, states): """Persist states in the database, poke the notifier and send to interested remote servers """ - stream_id, max_token = yield self.store.update_presence(states) + stream_id, max_token = await self.store.update_presence(states) - parties = yield get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -606,9 +600,8 @@ class PresenceHandler(object): self._push_to_remotes(states) - @defer.inlineCallbacks - def notify_for_states(self, state, stream_id): - parties = yield get_interested_parties(self.store, [state]) + async def notify_for_states(self, state, stream_id): + parties = await get_interested_parties(self.store, [state]) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -626,8 +619,7 @@ class PresenceHandler(object): """ self.federation.send_presence(states) - @defer.inlineCallbacks - def incoming_presence(self, origin, content): + async def incoming_presence(self, origin, content): """Called when we receive a `m.presence` EDU from a remote server. """ now = self.clock.time_msec() @@ -670,21 +662,19 @@ class PresenceHandler(object): new_fields["status_msg"] = push.get("status_msg", None) new_fields["currently_active"] = push.get("currently_active", False) - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) updates.append(prev_state.copy_and_replace(**new_fields)) if updates: federation_presence_counter.inc(len(updates)) - yield self._update_states(updates) + await self._update_states(updates) - @defer.inlineCallbacks - def get_state(self, target_user, as_event=False): - results = yield self.get_states([target_user.to_string()], as_event=as_event) + async def get_state(self, target_user, as_event=False): + results = await self.get_states([target_user.to_string()], as_event=as_event) return results[0] - @defer.inlineCallbacks - def get_states(self, target_user_ids, as_event=False): + async def get_states(self, target_user_ids, as_event=False): """Get the presence state for users. Args: @@ -695,7 +685,7 @@ class PresenceHandler(object): list """ - updates = yield self.current_state_for_users(target_user_ids) + updates = await self.current_state_for_users(target_user_ids) updates = list(updates.values()) for user_id in set(target_user_ids) - {u.user_id for u in updates}: @@ -713,8 +703,7 @@ class PresenceHandler(object): else: return updates - @defer.inlineCallbacks - def set_state(self, target_user, state, ignore_status_msg=False): + async def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ status_msg = state.get("status_msg", None) @@ -730,7 +719,7 @@ class PresenceHandler(object): user_id = target_user.to_string() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"state": presence} @@ -741,16 +730,15 @@ class PresenceHandler(object): if presence == PresenceState.ONLINE: new_fields["last_active_ts"] = self.clock.time_msec() - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def is_visible(self, observed_user, observer_user): + async def is_visible(self, observed_user, observer_user): """Returns whether a user can see another user's presence. """ - observer_room_ids = yield self.store.get_rooms_for_user( + observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() ) - observed_room_ids = yield self.store.get_rooms_for_user( + observed_room_ids = await self.store.get_rooms_for_user( observed_user.to_string() ) @@ -759,8 +747,7 @@ class PresenceHandler(object): return False - @defer.inlineCallbacks - def get_all_presence_updates(self, last_id, current_id): + async def get_all_presence_updates(self, last_id, current_id): """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -775,7 +762,7 @@ class PresenceHandler(object): """ # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = yield self.store.get_all_presence_updates(last_id, current_id) + rows = await self.store.get_all_presence_updates(last_id, current_id) return rows def notify_new_event(self): @@ -786,20 +773,18 @@ class PresenceHandler(object): if self._event_processing: return - @defer.inlineCallbacks - def _process_presence(): + async def _process_presence(): assert not self._event_processing self._event_processing = True try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._event_processing = False run_as_background_process("presence.notify_new_event", _process_presence) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "presence_delta"): @@ -812,10 +797,10 @@ class PresenceHandler(object): self._event_pos, room_max_stream_ordering, ) - max_pos, deltas = yield self.store.get_current_state_deltas( + max_pos, deltas = await self.store.get_current_state_deltas( self._event_pos, room_max_stream_ordering ) - yield self._handle_state_delta(deltas) + await self._handle_state_delta(deltas) self._event_pos = max_pos @@ -824,8 +809,7 @@ class PresenceHandler(object): max_pos ) - @defer.inlineCallbacks - def _handle_state_delta(self, deltas): + async def _handle_state_delta(self, deltas): """Process current state deltas to find new joins that need to be handled. """ @@ -846,13 +830,13 @@ class PresenceHandler(object): # joins. continue - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event or event.content.get("membership") != Membership.JOIN: # We only care about joins continue if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if ( prev_event and prev_event.content.get("membership") == Membership.JOIN @@ -860,10 +844,9 @@ class PresenceHandler(object): # Ignore changes to join events. continue - yield self._on_user_joined_room(room_id, state_key) + await self._on_user_joined_room(room_id, state_key) - @defer.inlineCallbacks - def _on_user_joined_room(self, room_id, user_id): + async def _on_user_joined_room(self, room_id, user_id): """Called when we detect a user joining the room via the current state delta stream. @@ -882,8 +865,8 @@ class PresenceHandler(object): # TODO: We should be able to filter the hosts down to those that # haven't previously seen the user - state = yield self.current_state_for_user(user_id) - hosts = yield self.state.get_current_hosts_in_room(room_id) + state = await self.current_state_for_user(user_id) + hosts = await self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. hosts = {host for host in hosts if host != self.server_name} @@ -903,10 +886,10 @@ class PresenceHandler(object): # TODO: Check that this is actually a new server joining the # room. - user_ids = yield self.state.get_current_users_in_room(room_id) + user_ids = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, user_ids)) - states = yield self.current_state_for_users(user_ids) + states = await self.current_state_for_users(user_ids) # Filter out old presence, i.e. offline presence states where # the user hasn't been active for a week. We can change this @@ -996,9 +979,8 @@ class PresenceEventSource(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - @defer.inlineCallbacks @log_function - def get_new_events( + async def get_new_events( self, user, from_key, @@ -1045,7 +1027,7 @@ class PresenceEventSource(object): presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache - users_interested_in = yield self._get_interested_in(user, explicit_room_id) + users_interested_in = await self._get_interested_in(user, explicit_room_id) user_ids_changed = set() changed = None @@ -1071,7 +1053,7 @@ class PresenceEventSource(object): else: user_ids_changed = users_interested_in - updates = yield presence.current_state_for_users(user_ids_changed) + updates = await presence.current_state_for_users(user_ids_changed) if include_offline: return (list(updates.values()), max_token) @@ -1084,11 +1066,11 @@ class PresenceEventSource(object): def get_current_key(self): return self.store.get_current_presence_token() - def get_pagination_rows(self, user, pagination_config, key): - return self.get_new_events(user, from_key=None, include_offline=False) + async def get_pagination_rows(self, user, pagination_config, key): + return await self.get_new_events(user, from_key=None, include_offline=False) - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _get_interested_in(self, user, explicit_room_id, cache_context): + @cached(num_args=2, cache_context=True) + async def _get_interested_in(self, user, explicit_room_id, cache_context): """Returns the set of users that the given user should see presence updates for """ @@ -1096,13 +1078,13 @@ class PresenceEventSource(object): users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(users_who_share_room) if explicit_room_id: - user_ids = yield self.store.get_users_in_room( + user_ids = await self.store.get_users_in_room( explicit_room_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(user_ids) @@ -1277,8 +1259,8 @@ def get_interested_parties(store, states): 2-tuple: `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ - room_ids_to_states = {} - users_to_states = {} + room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] + users_to_states = {} # type: Dict[str, List[UserPresenceState]] for state in states: room_ids = yield store.get_rooms_for_user(state.user_id) for room_id in room_ids: diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ce60ae2e07..ce9d1fae12 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -323,7 +323,11 @@ class ReplicationStreamer(object): # We need to tell the presence handler that the connection has been # lost so that it can handle any ongoing syncs on that connection. - self.presence_handler.update_external_syncs_clear(connection.conn_id) + run_as_background_process( + "update_external_syncs_clear", + self.presence_handler.update_external_syncs_clear, + connection.conn_id, + ) def _batch_updates(updates): diff --git a/synapse/server.pyi b/synapse/server.pyi index 40eabfe5d9..3844f0e12f 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -3,6 +3,7 @@ import twisted.internet import synapse.api.auth import synapse.config.homeserver import synapse.crypto.keyring +import synapse.federation.federation_server import synapse.federation.sender import synapse.federation.transport.client import synapse.handlers @@ -107,5 +108,9 @@ class HomeServer(object): self, ) -> synapse.replication.tcp.client.ReplicationClientHandler: pass + def get_federation_registry( + self, + ) -> synapse.federation.federation_server.FederationHandlerRegistry: + pass def is_mine_id(self, domain_id: str) -> bool: pass diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 64915bafcd..05ea40a7de 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -494,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, "@test2:server") # Mark test2 as online, test will be offline with a last_active of 0 - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -543,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) # Mark test as online - self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + ) ) # Mark test2 as online, test will be offline with a last_active of 0. # Note we don't join them to the room yet - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) # Add servers to the room diff --git a/tox.ini b/tox.ini index b715ea0bff..4ccfde01b5 100644 --- a/tox.ini +++ b/tox.ini @@ -183,6 +183,7 @@ commands = mypy \ synapse/events/spamcheck.py \ synapse/federation/sender \ synapse/federation/transport \ + synapse/handlers/presence.py \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ synapse/logging/ \ -- cgit 1.4.1 From 4f21c33be301b8ea6369039c3ad8baa51878e4d5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Mar 2020 16:37:24 +0100 Subject: Remove usage of "conn_id" for presence. (#7128) * Remove `conn_id` usage for UserSyncCommand. Each tcp replication connection is assigned a "conn_id", which is used to give an ID to a remotely connected worker. In a redis world, there will no longer be a one to one mapping between connection and instance, so instead we need to replace such usages with an ID generated by the remote instances and included in the replicaiton commands. This really only effects UserSyncCommand. * Add CLEAR_USER_SYNCS command that is sent on shutdown. This should help with the case where a synchrotron gets restarted gracefully, rather than rely on 5 minute timeout. --- changelog.d/7128.misc | 1 + docs/tcp_replication.md | 6 ++++++ synapse/app/generic_worker.py | 20 ++++++++++++++++---- synapse/replication/tcp/client.py | 6 ++++-- synapse/replication/tcp/commands.py | 36 ++++++++++++++++++++++++++++++++---- synapse/replication/tcp/protocol.py | 9 +++++++-- synapse/replication/tcp/resource.py | 17 +++++++---------- synapse/server.py | 11 +++++++++++ synapse/server.pyi | 2 ++ 9 files changed, 86 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7128.misc (limited to 'synapse/server.pyi') diff --git a/changelog.d/7128.misc b/changelog.d/7128.misc new file mode 100644 index 0000000000..5703f6d2ec --- /dev/null +++ b/changelog.d/7128.misc @@ -0,0 +1 @@ +Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index d4f7d9ec18..3be8e50c4c 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -198,6 +198,12 @@ Asks the server for the current position of all streams. A user has started or stopped syncing +#### CLEAR_USER_SYNC (C) + + The server should clear all associated user sync data from the worker. + + This is used when a worker is shutting down. + #### FEDERATION_ACK (C) Acknowledge receipt of some federation data diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index fba7ad9551..1ee266f7c5 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -65,6 +65,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import ( AccountDataStream, DeviceListsStream, @@ -124,7 +125,6 @@ from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole -from synapse.util.stringutils import random_string from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.generic_worker") @@ -233,6 +233,7 @@ class GenericWorkerPresence(object): self.user_to_num_current_syncs = {} self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self.instance_id = hs.get_instance_id() active_presence = self.store.take_presence_startup_info() self.user_to_current_state = {state.user_id: state for state in active_presence} @@ -245,13 +246,24 @@ class GenericWorkerPresence(object): self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) - self.process_id = random_string(16) - logger.info("Presence process_id is %r", self.process_id) + hs.get_reactor().addSystemEventTrigger( + "before", + "shutdown", + run_as_background_process, + "generic_presence.on_shutdown", + self._on_shutdown, + ) + + def _on_shutdown(self): + if self.hs.config.use_presence: + self.hs.get_tcp_replication().send_command( + ClearUserSyncsCommand(self.instance_id) + ) def send_user_sync(self, user_id, is_syncing, last_sync_ms): if self.hs.config.use_presence: self.hs.get_tcp_replication().send_user_sync( - user_id, is_syncing, last_sync_ms + self.instance_id, user_id, is_syncing, last_sync_ms ) def mark_as_coming_online(self, user_id): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 7e7ad0f798..e86d9805f1 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -189,10 +189,12 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): """ self.send_command(FederationAckCommand(token)) - def send_user_sync(self, user_id, is_syncing, last_sync_ms): + def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): """Poke the master that a user has started/stopped syncing. """ - self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms)) + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) def send_remove_pusher(self, app_id, push_key, user_id): """Poke the master to remove a pusher for a user diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 5a6b734094..e4eec643f7 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -207,30 +207,32 @@ class UserSyncCommand(Command): Format:: - USER_SYNC + USER_SYNC Where is either "start" or "stop" """ NAME = "USER_SYNC" - def __init__(self, user_id, is_syncing, last_sync_ms): + def __init__(self, instance_id, user_id, is_syncing, last_sync_ms): + self.instance_id = instance_id self.user_id = user_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod def from_line(cls, line): - user_id, state, last_sync_ms = line.split(" ", 2) + instance_id, user_id, state, last_sync_ms = line.split(" ", 3) if state not in ("start", "end"): raise Exception("Invalid USER_SYNC state %r" % (state,)) - return cls(user_id, state == "start", int(last_sync_ms)) + return cls(instance_id, user_id, state == "start", int(last_sync_ms)) def to_line(self): return " ".join( ( + self.instance_id, self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), @@ -238,6 +240,30 @@ class UserSyncCommand(Command): ) +class ClearUserSyncsCommand(Command): + """Sent by the client to inform the server that it should drop all + information about syncing users sent by the client. + + Mainly used when client is about to shut down. + + Format:: + + CLEAR_USER_SYNC + """ + + NAME = "CLEAR_USER_SYNC" + + def __init__(self, instance_id): + self.instance_id = instance_id + + @classmethod + def from_line(cls, line): + return cls(line) + + def to_line(self): + return self.instance_id + + class FederationAckCommand(Command): """Sent by the client when it has processed up to a given point in the federation stream. This allows the master to drop in-memory caches of the @@ -398,6 +424,7 @@ _COMMANDS = ( InvalidateCacheCommand, UserIpCommand, RemoteServerUpCommand, + ClearUserSyncsCommand, ) # type: Tuple[Type[Command], ...] # Map of command name to command type. @@ -420,6 +447,7 @@ VALID_CLIENT_COMMANDS = ( ReplicateCommand.NAME, PingCommand.NAME, UserSyncCommand.NAME, + ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, InvalidateCacheCommand.NAME, diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index f81d2e2442..dae246825f 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -423,9 +423,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_USER_SYNC(self, cmd): await self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) + async def on_CLEAR_USER_SYNC(self, cmd): + await self.streamer.on_clear_user_syncs(cmd.instance_id) + async def on_REPLICATE(self, cmd): # Subscribe to all streams we're publishing to. for stream_name in self.streamer.streams_by_name: @@ -551,6 +554,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): ): BaseReplicationStreamProtocol.__init__(self, clock) + self.instance_id = hs.get_instance_id() + self.client_name = client_name self.server_name = server_name self.handler = handler @@ -580,7 +585,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): currently_syncing = self.handler.get_currently_syncing_users() now = self.clock.time_msec() for user_id in currently_syncing: - self.send_command(UserSyncCommand(user_id, True, now)) + self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) # We've now finished connecting to so inform the client handler self.handler.update_connection(self) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 4374e99e32..8b6067e20d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -251,14 +251,19 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") - async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() await self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms + instance_id, user_id, is_syncing, last_sync_ms ) + async def on_clear_user_syncs(self, instance_id): + """A replication client wants us to drop all their UserSync data. + """ + await self.presence_handler.update_external_syncs_clear(instance_id) + @measure_func("repl.on_remove_pusher") async def on_remove_pusher(self, app_id, push_key, user_id): """A client has asked us to remove a pusher @@ -321,14 +326,6 @@ class ReplicationStreamer(object): except ValueError: pass - # We need to tell the presence handler that the connection has been - # lost so that it can handle any ongoing syncs on that connection. - run_as_background_process( - "update_external_syncs_clear", - self.presence_handler.update_external_syncs_clear, - connection.conn_id, - ) - def _batch_updates(updates): """Takes a list of updates of form [(token, row)] and sets the token to diff --git a/synapse/server.py b/synapse/server.py index c7ca2bda0d..cd86475d6b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -103,6 +103,7 @@ from synapse.storage import DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor +from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) @@ -230,6 +231,8 @@ class HomeServer(object): self._listening_services = [] self.start_time = None + self.instance_id = random_string(5) + self.clock = Clock(reactor) self.distributor = Distributor() self.ratelimiter = Ratelimiter() @@ -242,6 +245,14 @@ class HomeServer(object): for depname in kwargs: setattr(self, depname, kwargs[depname]) + def get_instance_id(self): + """A unique ID for this synapse process instance. + + This is used to distinguish running instances in worker-based + deployments. + """ + return self.instance_id + def setup(self): logger.info("Setting up.") self.start_time = int(self.get_clock().time()) diff --git a/synapse/server.pyi b/synapse/server.pyi index 3844f0e12f..9d1dfa71e7 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -114,3 +114,5 @@ class HomeServer(object): pass def is_mine_id(self, domain_id: str) -> bool: pass + def get_instance_id(self) -> str: + pass -- cgit 1.4.1 From 5016b162fcf0372fe35404c64f80aeaf21461f31 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Apr 2020 09:58:42 +0100 Subject: Move client command handling out of TCP protocol (#7185) The aim here is to move the command handling out of the TCP protocol classes and to also merge the client and server command handling (so that we can reuse them for redis protocol). This PR simply moves the client paths to the new `ReplicationCommandHandler`, a future PR will move the server paths too. --- changelog.d/7185.misc | 1 + synapse/app/admin_cmd.py | 12 -- synapse/app/generic_worker.py | 9 +- synapse/replication/tcp/__init__.py | 30 ++- synapse/replication/tcp/client.py | 179 +++--------------- synapse/replication/tcp/handler.py | 252 +++++++++++++++++++++++++ synapse/replication/tcp/protocol.py | 197 +++---------------- synapse/server.py | 8 +- synapse/server.pyi | 7 +- tests/replication/slave/storage/_base.py | 15 +- tests/replication/tcp/streams/_base.py | 38 ++-- tests/replication/tcp/streams/test_receipts.py | 1 - 12 files changed, 378 insertions(+), 371 deletions(-) create mode 100644 changelog.d/7185.misc create mode 100644 synapse/replication/tcp/handler.py (limited to 'synapse/server.pyi') diff --git a/changelog.d/7185.misc b/changelog.d/7185.misc new file mode 100644 index 0000000000..deb9ca7021 --- /dev/null +++ b/changelog.d/7185.misc @@ -0,0 +1 @@ +Move client command handling out of TCP protocol. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 1c7c6ec0c8..a37818fe9a 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -43,7 +43,6 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -79,17 +78,6 @@ class AdminCmdServer(HomeServer): def start_listening(self, listeners): pass - def build_tcp_replication(self): - return AdminCmdReplicationHandler(self) - - -class AdminCmdReplicationHandler(ReplicationClientHandler): - async def on_rdata(self, stream_name, token, rows): - pass - - def get_streams_to_replicate(self): - return {} - @defer.inlineCallbacks def export_data_command(hs, args): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 174bef360f..dcd0709a02 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -64,7 +64,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import ( AccountDataStream, @@ -603,7 +603,7 @@ class GenericWorkerServer(HomeServer): def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - def build_tcp_replication(self): + def build_replication_data_handler(self): return GenericWorkerReplicationHandler(self) def build_presence_handler(self): @@ -613,7 +613,7 @@ class GenericWorkerServer(HomeServer): return GenericWorkerTyping(self) -class GenericWorkerReplicationHandler(ReplicationClientHandler): +class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore()) @@ -644,9 +644,6 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): args.update(self.send_handler.stream_positions()) return args - def get_currently_syncing_users(self): - return self.presence_handler.get_currently_syncing_users() - async def process_and_notify(self, stream_name, token, rows): try: if self.send_handler: diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py index 81c2ea7ee9..523a1358d4 100644 --- a/synapse/replication/tcp/__init__.py +++ b/synapse/replication/tcp/__init__.py @@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst Structure of the module: - * client.py - the client classes used for workers to connect to master + * handler.py - the classes used to handle sending/receiving commands to + replication * command.py - the definitions of all the valid commands - * protocol.py - contains bot the client and server protocol implementations, - these should not be used directly - * resource.py - the server classes that accepts and handle client connections - * streams.py - the definitons of all the valid streams + * protocol.py - the TCP protocol classes + * resource.py - handles streaming stream updates to replications + * streams/ - the definitons of all the valid streams + +The general interaction of the classes are: + + +---------------------+ + | ReplicationStreamer | + +---------------------+ + | + v + +---------------------------+ +----------------------+ + | ReplicationCommandHandler |---->|ReplicationDataHandler| + +---------------------------+ +----------------------+ + | ^ + v | + +-------------+ + | Protocols | + | (TCP/redis) | + +-------------+ + +Where the ReplicationDataHandler (or subclasses) handles incoming stream +updates. """ diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e86d9805f1..700ae79158 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,26 +16,16 @@ """ import logging -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict -from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.tcp.protocol import ( - AbstractReplicationClientHandler, - ClientReplicationStreamProtocol, -) - -from .commands import ( - Command, - FederationAckCommand, - InvalidateCacheCommand, - RemoteServerUpCommand, - RemovePusherCommand, - UserIpCommand, - UserSyncCommand, -) +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.replication.tcp.handler import ReplicationCommandHandler logger = logging.getLogger(__name__) @@ -44,16 +34,20 @@ class ReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. - Accepts a handler that will be called when new data is available or data - is required. + Accepts a handler that is passed to `ClientReplicationStreamProtocol`. """ initialDelay = 0.1 maxDelay = 1 # Try at least once every N seconds - def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): + def __init__( + self, + hs: "HomeServer", + client_name: str, + command_handler: "ReplicationCommandHandler", + ): self.client_name = client_name - self.handler = handler + self.command_handler = command_handler self.server_name = hs.config.server_name self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class @@ -66,7 +60,11 @@ class ReplicationClientFactory(ReconnectingClientFactory): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - self.hs, self.client_name, self.server_name, self._clock, self.handler, + self.hs, + self.client_name, + self.server_name, + self._clock, + self.command_handler, ) def clientConnectionLost(self, connector, reason): @@ -78,41 +76,17 @@ class ReplicationClientFactory(ReconnectingClientFactory): ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(AbstractReplicationClientHandler): - """A base handler that can be passed to the ReplicationClientFactory. +class ReplicationDataHandler: + """Handles incoming stream updates from replication. - By default proxies incoming replication data to the SlaveStore. + This instance notifies the slave data store about updates. Can be subclassed + to handle updates in additional ways. """ def __init__(self, store: BaseSlavedStore): self.store = store - # The current connection. None if we are currently (re)connecting - self.connection = None - - # Any pending commands to be sent once a new connection has been - # established - self.pending_commands = [] # type: List[Command] - - # Map from string -> deferred, to wake up when receiveing a SYNC with - # the given string. - # Used for tests. - self.awaiting_syncs = {} # type: Dict[str, defer.Deferred] - - # The factory used to create connections. - self.factory = None # type: Optional[ReplicationClientFactory] - - def start_replication(self, hs): - """Helper method to start a replication connection to the remote server - using TCP. - """ - client_name = hs.config.worker_name - self.factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self.factory) - - async def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name: str, token: int, rows: list): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -124,30 +98,8 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - logger.debug("Received rdata %s -> %s", stream_name, token) self.store.process_replication_rows(stream_name, token, rows) - async def on_position(self, stream_name, token): - """Called when we get new position data. By default this just pokes - the slave store. - - Can be overriden in subclasses to handle more. - """ - self.store.process_replication_rows(stream_name, token, []) - - def on_sync(self, data): - """When we received a SYNC we wake up any deferreds that were waiting - for the sync with the given data. - - Used by tests. - """ - d = self.awaiting_syncs.pop(data, None) - if d: - d.callback(data) - - def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. @@ -163,85 +115,10 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): args["account_data"] = user_account_data elif room_account_data: args["account_data"] = room_account_data - return args - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users. (Overriden by the synchrotron's only) - """ - return [] - - def send_command(self, cmd): - """Send a command to master (when we get establish a connection if we - don't have one already.) - """ - if self.connection: - self.connection.send_command(cmd) - else: - logger.warning("Queuing command as not connected: %r", cmd.NAME) - self.pending_commands.append(cmd) - - def send_federation_ack(self, token): - """Ack data for the federation stream. This allows the master to drop - data stored purely in memory. - """ - self.send_command(FederationAckCommand(token)) - - def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): - """Poke the master that a user has started/stopped syncing. - """ - self.send_command( - UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) - ) - - def send_remove_pusher(self, app_id, push_key, user_id): - """Poke the master to remove a pusher for a user - """ - cmd = RemovePusherCommand(app_id, push_key, user_id) - self.send_command(cmd) - - def send_invalidate_cache(self, cache_func, keys): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - - def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): - """Tell the master that the user made a request. - """ - cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) - self.send_command(cmd) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def await_sync(self, data): - """Returns a deferred that is resolved when we receive a SYNC command - with given data. - - [Not currently] used by tests. - """ - return self.awaiting_syncs.setdefault(data, defer.Deferred()) - - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - self.connection = connection - if connection: - for cmd in self.pending_commands: - connection.send_command(cmd) - self.pending_commands = [] - - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - logger.info("Finished connecting to server") + async def on_position(self, stream_name: str, token: int): + self.store.process_replication_rows(stream_name, token, []) - # We don't reset the delay any earlier as otherwise if there is a - # problem during start up we'll end up tight looping connecting to the - # server. - if self.factory: - self.factory.resetDelay() + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py new file mode 100644 index 0000000000..12a1cfd6d1 --- /dev/null +++ b/synapse/replication/tcp/handler.py @@ -0,0 +1,252 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Callable, Dict, List, Optional, Set + +from prometheus_client import Counter + +from synapse.replication.tcp.client import ReplicationClientFactory +from synapse.replication.tcp.commands import ( + Command, + FederationAckCommand, + InvalidateCacheCommand, + PositionCommand, + RdataCommand, + RemoteServerUpCommand, + RemovePusherCommand, + SyncCommand, + UserIpCommand, + UserSyncCommand, +) +from synapse.replication.tcp.streams import STREAMS_MAP, Stream +from synapse.util.async_helpers import Linearizer + +logger = logging.getLogger(__name__) + + +# number of updates received for each RDATA stream +inbound_rdata_count = Counter( + "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] +) + + +class ReplicationCommandHandler: + """Handles incoming commands from replication as well as sending commands + back out to connections. + """ + + def __init__(self, hs): + self._replication_data_handler = hs.get_replication_data_handler() + self._presence_handler = hs.get_presence_handler() + + # Set of streams that we've caught up with. + self._streams_connected = set() # type: Set[str] + + self._streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + + self._position_linearizer = Linearizer("replication_position") + + # Map of stream to batched updates. See RdataCommand for info on how + # batching works. + self._pending_batches = {} # type: Dict[str, List[Any]] + + # The factory used to create connections. + self._factory = None # type: Optional[ReplicationClientFactory] + + # The current connection. None if we are currently (re)connecting + self._connection = None + + def start_replication(self, hs): + """Helper method to start a replication connection to the remote server + using TCP. + """ + client_name = hs.config.worker_name + self._factory = ReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self._factory) + + async def on_RDATA(self, cmd: RdataCommand): + stream_name = cmd.stream_name + inbound_rdata_count.labels(stream_name).inc() + + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception: + logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) + raise + + if cmd.token is None or stream_name not in self._streams_connected: + # I.e. either this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token) or we're currently connecting so we queue up rows. + self._pending_batches.setdefault(stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + await self.on_rdata(stream_name, cmd.token, rows) + + async def on_rdata(self, stream_name: str, token: int, rows: list): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name: name of the replication stream for this batch of rows + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + """ + logger.debug("Received rdata %s -> %s", stream_name, token) + await self._replication_data_handler.on_rdata(stream_name, token, rows) + + async def on_POSITION(self, cmd: PositionCommand): + stream = self._streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # We protect catching up with a linearizer in case the replication + # connection reconnects under us. + with await self._position_linearizer.queue(cmd.stream_name): + # We're about to go and catch up with the stream, so mark as connecting + # to stop RDATA being handled at the same time by removing stream from + # list of connected streams. We also clear any batched up RDATA from + # before we got the POSITION. + self._streams_connected.discard(cmd.stream_name) + self._pending_batches.clear() + + # Find where we previously streamed up to. + current_token = self._replication_data_handler.get_streams_to_replicate().get( + cmd.stream_name + ) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", + cmd.stream_name, + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + if updates: + await self.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) + + # Handle any RDATA that came in while we were catching up. + rows = self._pending_batches.pop(cmd.stream_name, []) + if rows: + await self._replication_data_handler.on_rdata( + cmd.stream_name, rows[-1].token, rows + ) + + self._streams_connected.add(cmd.stream_name) + + async def on_SYNC(self, cmd: SyncCommand): + pass + + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + """"Called when get a new REMOTE_SERVER_UP command.""" + self._replication_data_handler.on_remote_server_up(cmd.data) + + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users. + """ + return self._presence_handler.get_currently_syncing_users() + + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + self._connection = connection + + def finished_connecting(self): + """Called when we have successfully subscribed and caught up to all + streams we're interested in. + """ + logger.info("Finished connecting to server") + + # We don't reset the delay any earlier as otherwise if there is a + # problem during start up we'll end up tight looping connecting to the + # server. + if self._factory: + self._factory.resetDelay() + + def send_command(self, cmd: Command): + """Send a command to master (when we get establish a connection if we + don't have one already.) + """ + if self._connection: + self._connection.send_command(cmd) + else: + logger.warning("Dropping command as not connected: %r", cmd.NAME) + + def send_federation_ack(self, token: int): + """Ack data for the federation stream. This allows the master to drop + data stored purely in memory. + """ + self.send_command(FederationAckCommand(token)) + + def send_user_sync( + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + ): + """Poke the master that a user has started/stopped syncing. + """ + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) + + def send_remove_pusher(self, app_id: str, push_key: str, user_id: str): + """Poke the master to remove a pusher for a user + """ + cmd = RemovePusherCommand(app_id, push_key, user_id) + self.send_command(cmd) + + def send_invalidate_cache(self, cache_func: Callable, keys: tuple): + """Poke the master to invalidate a cache. + """ + cmd = InvalidateCacheCommand(cache_func.__name__, keys) + self.send_command(cmd) + + def send_user_ip( + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: str, + last_seen: int, + ): + """Tell the master that the user made a request. + """ + cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) + self.send_command(cmd) + + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index dae246825f..f2a37f568e 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,12 +46,11 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ -import abc import fcntl import logging import struct from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set +from typing import TYPE_CHECKING, DefaultDict, List from six import iteritems @@ -78,13 +77,12 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) -from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string -MYPY = False -if MYPY: +if TYPE_CHECKING: + from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.server import HomeServer @@ -475,71 +473,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.streamer.lost_connection(self) -class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): - """ - The interface for the handler that should be passed to - ClientReplicationStreamProtocol - """ - - @abc.abstractmethod - async def on_rdata(self, stream_name, token, rows): - """Called to handle a batch of replication data with a given stream token. - - Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. - """ - raise NotImplementedError() - - @abc.abstractmethod - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - raise NotImplementedError() - - @abc.abstractmethod - def on_sync(self, data): - """Called when get a new SYNC command.""" - raise NotImplementedError() - - @abc.abstractmethod - async def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - raise NotImplementedError() - - @abc.abstractmethod - def get_streams_to_replicate(self): - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users.""" - raise NotImplementedError() - - @abc.abstractmethod - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - raise NotImplementedError() - - @abc.abstractmethod - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - raise NotImplementedError() - - class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS @@ -550,7 +483,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): client_name: str, server_name: str, clock: Clock, - handler: AbstractReplicationClientHandler, + command_handler: "ReplicationCommandHandler", ): BaseReplicationStreamProtocol.__init__(self, clock) @@ -558,20 +491,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.client_name = client_name self.server_name = server_name - self.handler = handler - - self.streams = { - stream.NAME: stream(hs) for stream in STREAMS_MAP.values() - } # type: Dict[str, Stream] - - # Set of stream names that have been subscribe to, but haven't yet - # caught up with. This is used to track when the client has been fully - # connected to the remote. - self.streams_connecting = set(STREAMS_MAP) # type: Set[str] - - # Map of stream to batched updates. See RdataCommand for info on how - # batching works. - self.pending_batches = {} # type: Dict[str, List[Any]] + self.handler = command_handler def connectionMade(self): self.send_command(NameCommand(self.client_name)) @@ -589,89 +509,39 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) + self.handler.finished_connecting() - async def on_SERVER(self, cmd): - if cmd.data != self.server_name: - logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) - self.send_error("Wrong remote") - - async def on_RDATA(self, cmd): - stream_name = cmd.stream_name - inbound_rdata_count.labels(stream_name).inc() - - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception( - "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row - ) - raise - - if cmd.token is None or stream_name in self.streams_connecting: - # I.e. this is part of a batch of updates for this stream. Batch - # until we get an update for the stream with a non None token - self.pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(stream_name, []) - rows.append(row) - await self.handler.on_rdata(stream_name, cmd.token, rows) - - async def on_POSITION(self, cmd: PositionCommand): - stream = self.streams.get(cmd.stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) - return - - # Find where we previously streamed up to. - current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name - ) - return - - # Fetch all updates between then and now. - limited = True - while limited: - updates, current_token, limited = await stream.get_updates_since( - current_token, cmd.token - ) - - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return - - if updates: - await self.handler.on_rdata( - cmd.stream_name, - current_token, - [stream.parse_row(update[1]) for update in updates], - ) + async def handle_command(self, cmd: Command): + """Handle a command we have received over the replication stream. - # We've now caught up to position sent to us, notify handler. - await self.handler.on_position(cmd.stream_name, cmd.token) + Delegates to `command_handler.on_`, which must return an + awaitable. - self.streams_connecting.discard(cmd.stream_name) - if not self.streams_connecting: - self.handler.finished_connecting() + Args: + cmd: received command + """ + handled = False - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True - # Handle any RDATA that came in while we were catching up. - rows = self.pending_batches.pop(cmd.stream_name, []) - if rows: - await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) + # Then call out to the handler. + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True - async def on_SYNC(self, cmd): - self.handler.on_sync(cmd.data) + if not handled: + logger.warning("Unhandled command: %r", cmd) - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.handler.on_remote_server_up(cmd.data) + async def on_SERVER(self, cmd): + if cmd.data != self.server_name: + logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) + self.send_error("Wrong remote") def replicate(self): """Send the subscription request to the server @@ -768,8 +638,3 @@ tcp_outbound_commands = LaterGauge( for k, count in iteritems(p.outbound_commands_counter) }, ) - -# number of updates received for each RDATA stream -inbound_rdata_count = Counter( - "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] -) diff --git a/synapse/server.py b/synapse/server.py index 9228e1c892..9d273c980c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -87,6 +87,8 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool +from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.rest.media.v1.media_repository import ( MediaRepository, @@ -206,6 +208,7 @@ class HomeServer(object): "password_policy_handler", "storage", "replication_streamer", + "replication_data_handler", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -468,7 +471,7 @@ class HomeServer(object): return ReadMarkerHandler(self) def build_tcp_replication(self): - raise NotImplementedError() + return ReplicationCommandHandler(self) def build_action_generator(self): return ActionGenerator(self) @@ -562,6 +565,9 @@ class HomeServer(object): def build_replication_streamer(self) -> ReplicationStreamer: return ReplicationStreamer(self) + def build_replication_data_handler(self): + return ReplicationDataHandler(self.get_datastore()) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9d1dfa71e7..9013e9bac9 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -19,6 +19,7 @@ import synapse.handlers.set_password import synapse.http.client import synapse.notifier import synapse.replication.tcp.client +import synapse.replication.tcp.handler import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -106,7 +107,11 @@ class HomeServer(object): pass def get_tcp_replication( self, - ) -> synapse.replication.tcp.client.ReplicationClientHandler: + ) -> synapse.replication.tcp.handler.ReplicationCommandHandler: + pass + def get_replication_data_handler( + self, + ) -> synapse.replication.tcp.client.ReplicationDataHandler: pass def get_federation_registry( self, diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 2a1e7c7166..8902a5ab69 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -17,8 +17,9 @@ from mock import Mock, NonCallableMock from synapse.replication.tcp.client import ( ReplicationClientFactory, - ReplicationClientHandler, + ReplicationDataHandler, ) +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.storage.database import make_conn @@ -51,15 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer + self.streamer = hs.get_replication_streamer() - handler_factory = Mock() - self.replication_handler = ReplicationClientHandler(self.slaved_store) - self.replication_handler.factory = handler_factory + # We now do some gut wrenching so that we have a client that is based + # off of the slave store rather than the main store. + self.replication_handler = ReplicationCommandHandler(self.hs) + self.replication_handler._replication_data_handler = ReplicationDataHandler( + self.slaved_store + ) client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) + client_factory.handler = self.replication_handler server = server_factory.buildProtocol(None) client = client_factory.buildProtocol(None) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index a755fe2879..32238fe79a 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -15,7 +15,7 @@ from mock import Mock -from synapse.replication.tcp.commands import ReplicateCommand +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -26,15 +26,20 @@ from tests.server import FakeTransport class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + def make_homeserver(self, reactor, clock): + self.test_handler = Mock(wraps=TestReplicationDataHandler()) + return self.setup_test_homeserver(replication_data_handler=self.test_handler) + def prepare(self, reactor, clock, hs): # build a replication server - server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer + server_factory = ReplicationStreamProtocolFactory(hs) + self.streamer = hs.get_replication_streamer() self.server = server_factory.buildProtocol(None) - self.test_handler = Mock(wraps=TestReplicationClientHandler()) + repl_handler = ReplicationCommandHandler(hs) + repl_handler.handler = self.test_handler self.client = ClientReplicationStreamProtocol( - hs, "client", "test", clock, self.test_handler, + hs, "client", "test", clock, repl_handler, ) self._client_transport = None @@ -69,13 +74,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump(0.1) - def replicate_stream(self): - """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand()) - -class TestReplicationClientHandler(object): - """Drop-in for ReplicationClientHandler which just collects RDATA rows""" +class TestReplicationDataHandler: + """Drop-in for ReplicationDataHandler which just collects RDATA rows""" def __init__(self): self.streams = set() @@ -88,18 +89,9 @@ class TestReplicationClientHandler(object): positions[stream] = max(token, positions.get(stream, 0)) return positions - def get_currently_syncing_users(self): - return [] - - def update_connection(self, connection): - pass - - def finished_connecting(self): - pass - - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - async def on_rdata(self, stream_name, token, rows): for r in rows: self._received_rdata_rows.append((stream_name, token, r)) + + async def on_position(self, stream_name, token): + pass diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 0ec0825a0e..a0206f7363 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -24,7 +24,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.reconnect() # make the client subscribe to the receipts stream - self.replicate_stream() self.test_handler.streams.add("receipts") # tell the master to send a new receipt -- cgit 1.4.1 From 71a1abb8a116372556fd577ff1b85c7cfbe3c2b3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 22 Apr 2020 22:39:04 +0100 Subject: Stop the master relaying USER_SYNC for other workers (#7318) Long story short: if we're handling presence on the current worker, we shouldn't be sending USER_SYNC commands over replication. In an attempt to figure out what is going on here, I ended up refactoring some bits of the presencehandler code, so the first 4 commits here are non-functional refactors to move this code slightly closer to sanity. (There's still plenty to do here :/). Suggest reviewing individual commits. Fixes (I hope) #7257. --- changelog.d/7318.misc | 1 + docs/tcp_replication.md | 6 +- synapse/api/constants.py | 2 + synapse/app/generic_worker.py | 85 ++++++++------- synapse/handlers/events.py | 20 ++-- synapse/handlers/initial_sync.py | 10 +- synapse/handlers/presence.py | 210 ++++++++++++++++++++---------------- synapse/replication/tcp/commands.py | 7 +- synapse/replication/tcp/handler.py | 15 +-- synapse/server.pyi | 2 +- 10 files changed, 199 insertions(+), 159 deletions(-) create mode 100644 changelog.d/7318.misc (limited to 'synapse/server.pyi') diff --git a/changelog.d/7318.misc b/changelog.d/7318.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7318.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index 3be8e50c4c..b922d9cf7e 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -196,7 +196,7 @@ Asks the server for the current position of all streams. #### USER_SYNC (C) - A user has started or stopped syncing + A user has started or stopped syncing on this process. #### CLEAR_USER_SYNC (C) @@ -216,10 +216,6 @@ Asks the server for the current position of all streams. Inform the server a cache should be invalidated -#### SYNC (S, C) - - Used exclusively in tests - ### REMOTE_SERVER_UP (S, C) Inform other processes that a remote server may have come back online. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index fda2c2e5bb..bcaf2c3600 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -97,6 +97,8 @@ class EventTypes(object): Retention = "m.room.retention" + Presence = "m.presence" + class RejectedReason(object): AUTH_ERROR = "auth_error" diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 37afd2f810..2a56fe0bd5 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -17,6 +17,9 @@ import contextlib import logging import sys +from typing import Dict, Iterable + +from typing_extensions import ContextManager from twisted.internet import defer, reactor from twisted.web.resource import NoResource @@ -38,14 +41,14 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.federation import send_queue from synapse.federation.transport.server import TransportLayerServer -from synapse.handlers.presence import PresenceHandler, get_interested_parties +from synapse.handlers.presence import BasePresenceHandler, get_interested_parties from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ +from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore @@ -225,23 +228,32 @@ class KeyUploadServlet(RestServlet): return 200, {"one_time_key_counts": result} +class _NullContextManager(ContextManager[None]): + """A context manager which does nothing.""" + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + UPDATE_SYNCING_USERS_MS = 10 * 1000 -class GenericWorkerPresence(object): +class GenericWorkerPresence(BasePresenceHandler): def __init__(self, hs): + super().__init__(hs) self.hs = hs self.is_mine_id = hs.is_mine_id self.http_client = hs.get_simple_http_client() - self.store = hs.get_datastore() - self.user_to_num_current_syncs = {} - self.clock = hs.get_clock() + + self._presence_enabled = hs.config.use_presence + + # The number of ongoing syncs on this process, by user id. + # Empty if _presence_enabled is false. + self._user_to_num_current_syncs = {} # type: Dict[str, int] + self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() - active_presence = self.store.take_presence_startup_info() - self.user_to_current_state = {state.user_id: state for state in active_presence} - # user_id -> last_sync_ms. Lists the users that have stopped syncing # but we haven't notified the master of that yet self.users_going_offline = {} @@ -259,13 +271,13 @@ class GenericWorkerPresence(object): ) def _on_shutdown(self): - if self.hs.config.use_presence: + if self._presence_enabled: self.hs.get_tcp_replication().send_command( ClearUserSyncsCommand(self.instance_id) ) def send_user_sync(self, user_id, is_syncing, last_sync_ms): - if self.hs.config.use_presence: + if self._presence_enabled: self.hs.get_tcp_replication().send_user_sync( self.instance_id, user_id, is_syncing, last_sync_ms ) @@ -307,28 +319,33 @@ class GenericWorkerPresence(object): # TODO Hows this supposed to work? return defer.succeed(None) - get_states = __func__(PresenceHandler.get_states) - get_state = __func__(PresenceHandler.get_state) - current_state_for_users = __func__(PresenceHandler.current_state_for_users) + async def user_syncing( + self, user_id: str, affect_presence: bool + ) -> ContextManager[None]: + """Record that a user is syncing. + + Called by the sync and events servlets to record that a user has connected to + this worker and is waiting for some events. + """ + if not affect_presence or not self._presence_enabled: + return _NullContextManager() - def user_syncing(self, user_id, affect_presence): - if affect_presence: - curr_sync = self.user_to_num_current_syncs.get(user_id, 0) - self.user_to_num_current_syncs[user_id] = curr_sync + 1 + curr_sync = self._user_to_num_current_syncs.get(user_id, 0) + self._user_to_num_current_syncs[user_id] = curr_sync + 1 - # If we went from no in flight sync to some, notify replication - if self.user_to_num_current_syncs[user_id] == 1: - self.mark_as_coming_online(user_id) + # If we went from no in flight sync to some, notify replication + if self._user_to_num_current_syncs[user_id] == 1: + self.mark_as_coming_online(user_id) def _end(): # We check that the user_id is in user_to_num_current_syncs because # user_to_num_current_syncs may have been cleared if we are # shutting down. - if affect_presence and user_id in self.user_to_num_current_syncs: - self.user_to_num_current_syncs[user_id] -= 1 + if user_id in self._user_to_num_current_syncs: + self._user_to_num_current_syncs[user_id] -= 1 # If we went from one in flight sync to non, notify replication - if self.user_to_num_current_syncs[user_id] == 0: + if self._user_to_num_current_syncs[user_id] == 0: self.mark_as_going_offline(user_id) @contextlib.contextmanager @@ -338,7 +355,7 @@ class GenericWorkerPresence(object): finally: _end() - return defer.succeed(_user_syncing()) + return _user_syncing() @defer.inlineCallbacks def notify_from_replication(self, states, stream_id): @@ -373,15 +390,12 @@ class GenericWorkerPresence(object): stream_id = token yield self.notify_from_replication(states, stream_id) - def get_currently_syncing_users(self): - if self.hs.config.use_presence: - return [ - user_id - for user_id, count in self.user_to_num_current_syncs.items() - if count > 0 - ] - else: - return set() + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + return [ + user_id + for user_id, count in self._user_to_num_current_syncs.items() + if count > 0 + ] class GenericWorkerTyping(object): @@ -625,8 +639,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): self.store = hs.get_datastore() self.typing_handler = hs.get_typing_handler() - # NB this is a SynchrotronPresence, not a normal PresenceHandler - self.presence_handler = hs.get_presence_handler() + self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence self.notifier = hs.get_notifier() self.notify_pushers = hs.config.start_pushers diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index ec18a42a68..71a89f09c7 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -19,6 +19,7 @@ import random from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase +from synapse.handlers.presence import format_user_presence_state from synapse.logging.utils import log_function from synapse.types import UserID from synapse.visibility import filter_events_for_client @@ -97,6 +98,8 @@ class EventStreamHandler(BaseHandler): explicit_room_id=room_id, ) + time_now = self.clock.time_msec() + # When the user joins a new room, or another user joins a currently # joined room, we need to send down presence for those users. to_add = [] @@ -112,19 +115,20 @@ class EventStreamHandler(BaseHandler): users = await self.state.get_current_users_in_room( event.room_id ) - states = await presence_handler.get_states(users, as_event=True) - to_add.extend(states) else: + users = [event.state_key] - ev = await presence_handler.get_state( - UserID.from_string(event.state_key), as_event=True - ) - to_add.append(ev) + states = await presence_handler.get_states(users) + to_add.extend( + { + "type": EventTypes.Presence, + "content": format_user_presence_state(state, time_now), + } + for state in states + ) events.extend(to_add) - time_now = self.clock.time_msec() - chunks = await self._event_serializer.serialize_events( events, time_now, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index b116500c7d..f88bad5f25 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -381,10 +381,16 @@ class InitialSyncHandler(BaseHandler): return [] states = await presence_handler.get_states( - [m.user_id for m in room_members], as_event=True + [m.user_id for m in room_members] ) - return states + return [ + { + "type": EventTypes.Presence, + "content": format_user_presence_state(s, time_now), + } + for s in states + ] async def get_receipts(): receipts = await self.store.get_linearized_receipts_for_room( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6912165622..5cbefae177 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,10 +22,10 @@ The methods that define policy are: - PresenceHandler._handle_timeouts - should_notify """ - +import abc import logging from contextlib import contextmanager -from typing import Dict, List, Set +from typing import Dict, Iterable, List, Set from six import iteritems, itervalues @@ -41,7 +42,7 @@ from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState -from synapse.types import UserID, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure @@ -99,13 +100,106 @@ EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000 assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER -class PresenceHandler(object): +class BasePresenceHandler(abc.ABC): + """Parts of the PresenceHandler that are shared between workers and master""" + + def __init__(self, hs: "synapse.server.HomeServer"): + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + active_presence = self.store.take_presence_startup_info() + self.user_to_current_state = {state.user_id: state for state in active_presence} + + @abc.abstractmethod + async def user_syncing( + self, user_id: str, affect_presence: bool + ) -> ContextManager[None]: + """Returns a context manager that should surround any stream requests + from the user. + + This allows us to keep track of who is currently streaming and who isn't + without having to have timers outside of this module to avoid flickering + when users disconnect/reconnect. + + Args: + user_id: the user that is starting a sync + affect_presence: If false this function will be a no-op. + Useful for streams that are not associated with an actual + client that is being used by a user. + """ + + @abc.abstractmethod + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + """Get an iterable of syncing users on this worker, to send to the presence handler + + This is called when a replication connection is established. It should return + a list of user ids, which are then sent as USER_SYNC commands to inform the + process handling presence about those users. + + Returns: + An iterable of user_id strings. + """ + + async def get_state(self, target_user: UserID) -> UserPresenceState: + results = await self.get_states([target_user.to_string()]) + return results[0] + + async def get_states( + self, target_user_ids: Iterable[str] + ) -> List[UserPresenceState]: + """Get the presence state for users.""" + + updates_d = await self.current_state_for_users(target_user_ids) + updates = list(updates_d.values()) + + for user_id in set(target_user_ids) - {u.user_id for u in updates}: + updates.append(UserPresenceState.default(user_id)) + + return updates + + async def current_state_for_users( + self, user_ids: Iterable[str] + ) -> Dict[str, UserPresenceState]: + """Get the current presence state for multiple users. + + Returns: + dict: `user_id` -> `UserPresenceState` + """ + states = { + user_id: self.user_to_current_state.get(user_id, None) + for user_id in user_ids + } + + missing = [user_id for user_id, state in iteritems(states) if not state] + if missing: + # There are things not in our in memory cache. Lets pull them out of + # the database. + res = await self.store.get_presence_for_users(missing) + states.update(res) + + missing = [user_id for user_id, state in iteritems(states) if not state] + if missing: + new = { + user_id: UserPresenceState.default(user_id) for user_id in missing + } + states.update(new) + self.user_to_current_state.update(new) + + return states + + @abc.abstractmethod + async def set_state( + self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False + ) -> None: + """Set the presence state of the user. """ + + +class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "synapse.server.HomeServer"): + super().__init__(hs) self.hs = hs self.is_mine_id = hs.is_mine_id self.server_name = hs.hostname - self.clock = hs.get_clock() - self.store = hs.get_datastore() self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() self.federation = hs.get_federation_sender() @@ -115,13 +209,6 @@ class PresenceHandler(object): federation_registry.register_edu_handler("m.presence", self.incoming_presence) - active_presence = self.store.take_presence_startup_info() - - # A dictionary of the current state of users. This is prefilled with - # non-offline presence from the DB. We should fetch from the DB if - # we can't find a users presence in here. - self.user_to_current_state = {state.user_id: state for state in active_presence} - LaterGauge( "synapse_handlers_presence_user_to_current_state_size", "", @@ -130,7 +217,7 @@ class PresenceHandler(object): ) now = self.clock.time_msec() - for state in active_presence: + for state in self.user_to_current_state.values(): self.wheel_timer.insert( now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER ) @@ -361,10 +448,18 @@ class PresenceHandler(object): timers_fired_counter.inc(len(states)) + syncing_user_ids = { + user_id + for user_id, count in self.user_to_num_current_syncs.items() + if count + } + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) + changes = handle_timeouts( states, is_mine_fn=self.is_mine_id, - syncing_user_ids=self.get_currently_syncing_users(), + syncing_user_ids=syncing_user_ids, now=now, ) @@ -462,22 +557,9 @@ class PresenceHandler(object): return _user_syncing() - def get_currently_syncing_users(self): - """Get the set of user ids that are currently syncing on this HS. - Returns: - set(str): A set of user_id strings. - """ - if self.hs.config.use_presence: - syncing_user_ids = { - user_id - for user_id, count in self.user_to_num_current_syncs.items() - if count - } - for user_ids in self.external_process_to_current_syncs.values(): - syncing_user_ids.update(user_ids) - return syncing_user_ids - else: - return set() + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + # since we are the process handling presence, there is nothing to do here. + return [] async def update_external_syncs_row( self, process_id, user_id, is_syncing, sync_time_msec @@ -554,34 +636,6 @@ class PresenceHandler(object): res = await self.current_state_for_users([user_id]) return res[user_id] - async def current_state_for_users(self, user_ids): - """Get the current presence state for multiple users. - - Returns: - dict: `user_id` -> `UserPresenceState` - """ - states = { - user_id: self.user_to_current_state.get(user_id, None) - for user_id in user_ids - } - - missing = [user_id for user_id, state in iteritems(states) if not state] - if missing: - # There are things not in our in memory cache. Lets pull them out of - # the database. - res = await self.store.get_presence_for_users(missing) - states.update(res) - - missing = [user_id for user_id, state in iteritems(states) if not state] - if missing: - new = { - user_id: UserPresenceState.default(user_id) for user_id in missing - } - states.update(new) - self.user_to_current_state.update(new) - - return states - async def _persist_and_notify(self, states): """Persist states in the database, poke the notifier and send to interested remote servers @@ -669,40 +723,6 @@ class PresenceHandler(object): federation_presence_counter.inc(len(updates)) await self._update_states(updates) - async def get_state(self, target_user, as_event=False): - results = await self.get_states([target_user.to_string()], as_event=as_event) - - return results[0] - - async def get_states(self, target_user_ids, as_event=False): - """Get the presence state for users. - - Args: - target_user_ids (list) - as_event (bool): Whether to format it as a client event or not. - - Returns: - list - """ - - updates = await self.current_state_for_users(target_user_ids) - updates = list(updates.values()) - - for user_id in set(target_user_ids) - {u.user_id for u in updates}: - updates.append(UserPresenceState.default(user_id)) - - now = self.clock.time_msec() - if as_event: - return [ - { - "type": "m.presence", - "content": format_user_presence_state(state, now), - } - for state in updates - ] - else: - return updates - async def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ @@ -889,7 +909,7 @@ class PresenceHandler(object): user_ids = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, user_ids)) - states = await self.current_state_for_users(user_ids) + states_d = await self.current_state_for_users(user_ids) # Filter out old presence, i.e. offline presence states where # the user hasn't been active for a week. We can change this @@ -899,7 +919,7 @@ class PresenceHandler(object): now = self.clock.time_msec() states = [ state - for state in states.values() + for state in states_d.values() if state.state != PresenceState.OFFLINE or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 or state.status_msg is not None diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f26aee83cb..c7880d4b63 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -210,7 +210,10 @@ class ReplicateCommand(Command): class UserSyncCommand(Command): """Sent by the client to inform the server that a user has started or - stopped syncing. Used to calculate presence on the master. + stopped syncing on this process. + + This is used by the process handling presence (typically the master) to + calculate who is online and who is not. Includes a timestamp of when the last user sync was. @@ -218,7 +221,7 @@ class UserSyncCommand(Command): USER_SYNC - Where is either "start" or "stop" + Where is either "start" or "end" """ NAME = "USER_SYNC" diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 5b5ee2c13e..0db5a3a24d 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -337,13 +337,6 @@ class ReplicationCommandHandler: if self._is_master: self._notifier.notify_remote_server_up(cmd.data) - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users. - """ - return self._presence_handler.get_currently_syncing_users() - def new_connection(self, connection: AbstractConnection): """Called when we have a new connection. """ @@ -361,9 +354,11 @@ class ReplicationCommandHandler: if self._factory: self._factory.resetDelay() - # Tell the server if we have any users currently syncing (should only - # happen on synchrotrons) - currently_syncing = self.get_currently_syncing_users() + # Tell the other end if we have any users currently syncing. + currently_syncing = ( + self._presence_handler.get_currently_syncing_users_for_replication() + ) + now = self._clock.time_msec() for user_id in currently_syncing: connection.send_command( diff --git a/synapse/server.pyi b/synapse/server.pyi index 9013e9bac9..f1a5717028 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -97,7 +97,7 @@ class HomeServer(object): pass def get_notifier(self) -> synapse.notifier.Notifier: pass - def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler: + def get_presence_handler(self) -> synapse.handlers.presence.BasePresenceHandler: pass def get_clock(self) -> synapse.util.Clock: pass -- cgit 1.4.1 From c2e1a2110fbe9ead26b4ecbb1afd504ed035a04d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 29 Apr 2020 12:30:36 +0100 Subject: Fix limit logic for EventsStream (#7358) * Factor out functions for injecting events into database I want to add some more flexibility to the tools for injecting events into the database, and I don't want to clutter up HomeserverTestCase with them, so let's factor them out to a new file. * Rework TestReplicationDataHandler This wasn't very easy to work with: the mock wrapping was largely superfluous, and it's useful to be able to inspect the received rows, and clear out the received list. * Fix AssertionErrors being thrown by EventsStream Part of the problem was that there was an off-by-one error in the assertion, but also the limit logic was too simple. Fix it all up and add some tests. --- changelog.d/7358.bugfix | 1 + synapse/replication/tcp/handler.py | 4 +- synapse/replication/tcp/streams/events.py | 22 +- synapse/server.pyi | 5 + synapse/storage/data_stores/main/events_worker.py | 64 +++- tests/replication/tcp/streams/_base.py | 41 ++- tests/replication/tcp/streams/test_events.py | 417 ++++++++++++++++++++++ tests/replication/tcp/streams/test_receipts.py | 10 +- tests/replication/tcp/streams/test_typing.py | 11 +- tests/rest/client/v1/utils.py | 2 +- tests/test_utils/__init__.py | 20 ++ tests/test_utils/event_injection.py | 96 +++++ tests/unittest.py | 30 +- tox.ini | 2 + 14 files changed, 658 insertions(+), 67 deletions(-) create mode 100644 changelog.d/7358.bugfix create mode 100644 tests/replication/tcp/streams/test_events.py create mode 100644 tests/test_utils/event_injection.py (limited to 'synapse/server.pyi') diff --git a/changelog.d/7358.bugfix b/changelog.d/7358.bugfix new file mode 100644 index 0000000000..f49c600173 --- /dev/null +++ b/changelog.d/7358.bugfix @@ -0,0 +1 @@ +Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0db5a3a24d..3a8c7c7e2d 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -87,7 +87,9 @@ class ReplicationCommandHandler: stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] - self._position_linearizer = Linearizer("replication_position") + self._position_linearizer = Linearizer( + "replication_position", clock=self._clock + ) # Map of stream to batched updates. See RdataCommand for info on how # batching works. diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index aa50492569..52df81b1bd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -170,22 +170,16 @@ class EventsStream(Stream): limited = False upper_limit = current_token - # next up is the state delta table - - state_rows = await self._store.get_all_updated_current_state_deltas( + # next up is the state delta table. + ( + state_rows, + upper_limit, + state_rows_limited, + ) = await self._store.get_all_updated_current_state_deltas( from_token, upper_limit, target_row_count - ) # type: List[Tuple] - - # again, if we've hit the limit there, we'll need to limit the other sources - assert len(state_rows) < target_row_count - if len(state_rows) == target_row_count: - assert state_rows[-1][0] <= upper_limit - upper_limit = state_rows[-1][0] - limited = True + ) - # FIXME: is it a given that there is only one row per stream_id in the - # state_deltas table (so that we can be sure that we have got all of the - # rows for upper_limit)? + limited = limited or state_rows_limited # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. diff --git a/synapse/server.pyi b/synapse/server.pyi index f1a5717028..fc5886f762 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage +from synapse.events.builder import EventBuilderFactory class HomeServer(object): @property @@ -121,3 +122,7 @@ class HomeServer(object): pass def get_instance_id(self) -> str: pass + def get_event_builder_factory(self) -> EventBuilderFactory: + pass + def get_storage(self) -> synapse.storage.Storage: + pass diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index ce8be72bfe..73df6b33ba 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -19,7 +19,7 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional +from typing import List, Optional, Tuple from canonicaljson import json from constantly import NamedConstant, Names @@ -1084,7 +1084,28 @@ class EventsWorkerStore(SQLBaseStore): "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + async def get_all_updated_current_state_deltas( + self, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple], int, bool]: + """Fetch updates from current_state_delta_stream + + Args: + from_token: The previous stream token. Updates from this stream id will + be excluded. + + to_token: The current stream token (ie the upper limit). Updates up to this + stream id will be included (modulo the 'limit' param) + + target_row_count: The number of rows to try to return. If more rows are + available, we will set 'limited' in the result. In the event of a large + batch, we may return more rows than this. + Returns: + A triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of database tuples. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. + """ + def get_all_updated_current_state_deltas_txn(txn): sql = """ SELECT stream_id, room_id, type, state_key, event_id @@ -1092,10 +1113,45 @@ class EventsWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ - txn.execute(sql, (from_token, to_token, limit)) + txn.execute(sql, (from_token, to_token, target_row_count)) return txn.fetchall() - return self.db.runInteraction( + def get_deltas_for_stream_id_txn(txn, stream_id): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE stream_id = ? + """ + txn.execute(sql, [stream_id]) + return txn.fetchall() + + # we need to make sure that, for every stream id in the results, we get *all* + # the rows with that stream id. + + rows = await self.db.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, + ) # type: List[Tuple] + + # if we've got fewer rows than the limit, we're good + if len(rows) < target_row_count: + return rows, to_token, False + + # we hit the limit, so reduce the upper limit so that we exclude the stream id + # of the last row in the result. + assert rows[-1][0] <= to_token + to_token = rows[-1][0] - 1 + + # search backwards through the list for the point to truncate + for idx in range(len(rows) - 1, 0, -1): + if rows[idx - 1][0] <= to_token: + return rows[:idx], to_token, True + + # bother. We didn't get a full set of changes for even a single + # stream id. let's run the query again, without a row limit, but for + # just one stream id. + to_token += 1 + rows = await self.db.runInteraction( + "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token ) + return rows, to_token, True diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 82f15c64e0..83e16cfe3d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,10 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -from typing import Optional -from mock import Mock +import logging +from typing import Any, Dict, List, Optional, Tuple import attr @@ -25,6 +24,7 @@ from twisted.web.http import HTTPChannel from synapse.app.generic_worker import GenericWorkerServer from synapse.http.site import SynapseRequest +from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol @@ -65,9 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # databases objects are the same. self.worker_hs.get_datastore().db = hs.get_datastore().db - self.test_handler = Mock( - wraps=TestReplicationDataHandler(self.worker_hs.get_datastore()) - ) + self.test_handler = self._build_replication_data_handler() self.worker_hs.replication_data_handler = self.test_handler repl_handler = ReplicationCommandHandler(self.worker_hs) @@ -78,6 +76,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def _build_replication_data_handler(self): + return TestReplicationDataHandler(self.worker_hs.get_datastore()) + def reconnect(self): if self._client_transport: self.client.close() @@ -174,22 +175,28 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): class TestReplicationDataHandler(ReplicationDataHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" - def __init__(self, hs): - super().__init__(hs) - self.streams = set() - self._received_rdata_rows = [] + def __init__(self, store: BaseSlavedStore): + super().__init__(store) + + # streams to subscribe to: map from stream id to position + self.stream_positions = {} # type: Dict[str, int] + + # list of received (stream_name, token, row) tuples + self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] def get_streams_to_replicate(self): - positions = {s: 0 for s in self.streams} - for stream, token, _ in self._received_rdata_rows: - if stream in self.streams: - positions[stream] = max(token, positions.get(stream, 0)) - return positions + return self.stream_positions async def on_rdata(self, stream_name, token, rows): await super().on_rdata(stream_name, token, rows) for r in rows: - self._received_rdata_rows.append((stream_name, token, r)) + self.received_rdata_rows.append((stream_name, token, r)) + + if ( + stream_name in self.stream_positions + and token > self.stream_positions[stream_name] + ): + self.stream_positions[stream_name] = token @attr.s() @@ -221,7 +228,7 @@ class _PushHTTPChannel(HTTPChannel): super().__init__() self.reactor = reactor - self._pull_to_push_producer = None + self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] def registerProducer(self, producer, streaming): # Convert pull producers to push producer. diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py new file mode 100644 index 0000000000..1fa28084f9 --- /dev/null +++ b/tests/replication/tcp/streams/test_events.py @@ -0,0 +1,417 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, + EventsStreamRow, +) +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication.tcp.streams._base import BaseStreamTestCase +from tests.test_utils.event_injection import inject_event, inject_member_event + + +class EventsStreamTestCase(BaseStreamTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + super().prepare(reactor, clock, hs) + self.user_id = self.register_user("u1", "pass") + self.user_tok = self.login("u1", "pass") + + self.reconnect() + self.test_handler.stream_positions["events"] = 0 + + self.room_id = self.helper.create_room_as(tok=self.user_tok) + self.test_handler.received_rdata_rows.clear() + + def test_update_function_event_row_limit(self): + """Test replication with many non-state events + + Checks that all events are correctly replicated when there are lots of + event rows to be replicated. + """ + # disconnect, so that we can stack up some changes + self.disconnect() + + # generate lots of non-state events. We inject them using inject_event + # so that they are not send out over replication until we call self.replicate(). + events = [ + self._inject_test_event() + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1) + ] + + # also one state event + state_event = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + for event in events: + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + self.assertEqual([], received_rows) + + def test_update_function_huge_state_change(self): + """Test replication with many state events + + Ensures that all events are correctly replicated when there are lots of + state change rows to be replicated. + """ + + # we want to generate lots of state changes at a single stream ID. + # + # We do this by having two branches in the DAG. On one, we have a moderator + # which that generates lots of state; on the other, we de-op the moderator, + # thus invalidating all the state. + + OTHER_USER = "@other_user:localhost" + + # have the user join + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"][OTHER_USER] = 50 + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [ + self._inject_state_event(sender=OTHER_USER) + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT) + ] + + self.replicate() + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # a state event which doesn't get rolled back, to check that the state + # before the huge update comes through ok + state1 = self._inject_state_event() + + # roll back all the state by de-modding the user + prev_events = fork_point + pls["users"][OTHER_USER] = 0 + pl_event = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + + # one more bit of state that doesn't get rolled back + state2 = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # now we should have received all the expected rows in the right order. + # + # we expect: + # + # - two rows for state1 + # - the PL event row, plus state rows for the PL event and each + # of the states that got reverted. + # - two rows for state2 + + received_rows = self.test_handler.received_rdata_rows + + # first check the first two rows, which should be state1 + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state1.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state1.event_id) + + # now the last two rows, which should be state2 + stream_name, token, row = received_rows.pop(-2) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state2.event_id) + + stream_name, token, row = received_rows.pop(-1) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state2.event_id) + + # that should leave us with the rows for the PL event + self.assertEqual(len(received_rows), len(events) + 2) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_event.event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for stream_name, token, row in received_rows: + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_event.event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + def test_update_function_state_row_limit(self): + """Test replication with many state events over several stream ids. + """ + + # we want to generate lots of state changes, but for this test, we want to + # spread out the state changes over a few stream IDs. + # + # We do this by having two branches in the DAG. On one, we have four moderators, + # each of which that generates lots of state; on the other, we de-op the users, + # thus invalidating all the state. + + NUM_USERS = 4 + STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1 + + user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)] + + # have the users join + for u in user_ids: + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"].update({u: 50 for u in user_ids}) + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [] # type: List[EventBase] + for user in user_ids: + events.extend( + self._inject_state_event(sender=user) for _ in range(STATES_PER_USER) + ) + + self.replicate() + + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # now roll back all that state by de-modding the users + prev_events = fork_point + pl_events = [] + for u in user_ids: + pls["users"][u] = 0 + e = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + prev_events = [e.event_id] + pl_events.append(e) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + + received_rows = self.test_handler.received_rdata_rows + self.assertGreaterEqual(len(received_rows), len(events)) + for i in range(NUM_USERS): + # for each user, we expect the PL event row, followed by state rows for + # the PL event and each of the states that got reverted. + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_events[i].event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for j in range(STATES_PER_USER + 1): + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_events[i].event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + self.assertEqual([], received_rows) + + event_count = 0 + + def _inject_test_event( + self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs + ) -> EventBase: + if sender is None: + sender = self.user_id + + if body is None: + body = "event %i" % (self.event_count,) + self.event_count += 1 + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) + + def _inject_state_event( + self, + body: Optional[str] = None, + state_key: Optional[str] = None, + sender: Optional[str] = None, + ) -> EventBase: + if sender is None: + sender = self.user_id + + if state_key is None: + state_key = "state_%i" % (self.event_count,) + self.event_count += 1 + + if body is None: + body = "state event %s" % (state_key,) + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index a0206f7363..c122b8589c 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -12,6 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# type: ignore + +from mock import Mock + from synapse.replication.tcp.streams._base import ReceiptsStream from tests.replication.tcp.streams._base import BaseStreamTestCase @@ -20,11 +25,14 @@ USER_ID = "@feeling:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_receipt(self): self.reconnect() # make the client subscribe to the receipts stream - self.test_handler.streams.add("receipts") + self.test_handler.stream_positions.update({"receipts": 0}) # tell the master to send a new receipt self.get_success( diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index f0ad6402ae..4d354a9db8 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from synapse.handlers.typing import RoomMember from synapse.replication.http import streams from synapse.replication.tcp.streams import TypingStream @@ -26,6 +28,9 @@ class TypingStreamTestCase(BaseStreamTestCase): streams.register_servlets, ] + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_typing(self): typing = self.hs.get_typing_handler() @@ -33,8 +38,8 @@ class TypingStreamTestCase(BaseStreamTestCase): self.reconnect() - # make the client subscribe to the receipts stream - self.test_handler.streams.add("typing") + # make the client subscribe to the typing stream + self.test_handler.stream_positions.update({"typing": 0}) typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) @@ -75,6 +80,6 @@ class TypingStreamTestCase(BaseStreamTestCase): stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row = rdata_rows[0] self.assertEqual(room_id, row.room_id) self.assertEqual([], row.user_ids) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 371637618d..22d734e763 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -39,7 +39,7 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator, is_public=True, tok=None): + def create_room_as(self, room_creator=None, is_public=True, tok=None): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index a7310cf12a..7b345b03bb 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,3 +17,22 @@ """ Utilities for running the unit tests """ +from typing import Awaitable, TypeVar + +TV = TypeVar("TV") + + +def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: + """Get the result from an Awaitable which should have completed + + Asserts that the given awaitable has a result ready, and returns its value + """ + i = awaitable.__await__() + try: + next(i) + except StopIteration as e: + # awaitable returned a result + return e.value + + # if next didn't raise, the awaitable hasn't completed. + raise Exception("awaitable has not yet completed") diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py new file mode 100644 index 0000000000..8f6872761a --- /dev/null +++ b/tests/test_utils/event_injection.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import synapse.server +from synapse.api.constants import EventTypes +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase +from synapse.types import Collection + +from tests.test_utils import get_awaitable_result + + +""" +Utility functions for poking events into the storage of the server under test. +""" + + +def inject_member_event( + hs: synapse.server.HomeServer, + room_id: str, + sender: str, + membership: str, + target: Optional[str] = None, + extra_content: Optional[dict] = None, + **kwargs +) -> EventBase: + """Inject a membership event into a room.""" + if target is None: + target = sender + + content = {"membership": membership} + if extra_content: + content.update(extra_content) + + return inject_event( + hs, + room_id=room_id, + type=EventTypes.Member, + sender=sender, + state_key=target, + content=content, + **kwargs + ) + + +def inject_event( + hs: synapse.server.HomeServer, + room_version: Optional[str] = None, + prev_event_ids: Optional[Collection[str]] = None, + **kwargs +) -> EventBase: + """Inject a generic event into a room + + Args: + hs: the homeserver under test + room_version: the version of the room we're inserting into. + if not specified, will be looked up + prev_event_ids: prev_events for the event. If not specified, will be looked up + kwargs: fields for the event to be created + """ + test_reactor = hs.get_reactor() + + if room_version is None: + d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) + test_reactor.advance(0) + room_version = get_awaitable_result(d) + + builder = hs.get_event_builder_factory().for_room_version( + KNOWN_ROOM_VERSIONS[room_version], kwargs + ) + d = hs.get_event_creation_handler().create_new_client_event( + builder, prev_event_ids=prev_event_ids + ) + test_reactor.advance(0) + event, context = get_awaitable_result(d) + + d = hs.get_storage().persistence.persist_event(event, context) + test_reactor.advance(0) + get_awaitable_result(d) + + return event diff --git a/tests/unittest.py b/tests/unittest.py index 27af5228fe..6b6f224e9c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -32,7 +32,6 @@ from twisted.python.threadpool import ThreadPool from twisted.trial import unittest from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server @@ -55,6 +54,7 @@ from tests.server import ( render, setup_test_homeserver, ) +from tests.test_utils import event_injection from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -596,36 +596,14 @@ class HomeserverTestCase(TestCase): """ Inject a membership event into a room. + Deprecated: use event_injection.inject_room_member directly + Args: room: Room ID to inject the event into. user: MXID of the user to inject the membership for. membership: The membership type. """ - event_builder_factory = self.hs.get_event_builder_factory() - event_creation_handler = self.hs.get_event_creation_handler() - - room_version = self.get_success( - self.hs.get_datastore().get_room_version_id(room) - ) - - builder = event_builder_factory.for_room_version( - KNOWN_ROOM_VERSIONS[room_version], - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - event_creation_handler.create_new_client_event(builder) - ) - - self.get_success( - self.hs.get_storage().persistence.persist_event(event, context) - ) + event_injection.inject_member_event(self.hs, room, user, membership) class FederatingHomeserverTestCase(HomeserverTestCase): diff --git a/tox.ini b/tox.ini index 31011d7436..2630857436 100644 --- a/tox.ini +++ b/tox.ini @@ -204,6 +204,8 @@ commands = mypy \ synapse/storage/database.py \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ + tests/replication/tcp/streams \ + tests/test_utils \ tests/util/test_stream_change_cache.py # To find all folders that pass mypy you run: -- cgit 1.4.1 From 37f6823f5b91f27b9dd8de8fc0e52d5ea889647c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 Apr 2020 16:23:08 +0100 Subject: Add instance name to RDATA/POSITION commands (#7364) This is primarily for allowing us to send those commands from workers, but for now simply allows us to ignore echoed RDATA/POSITION commands that we sent (we get echoes of sent commands when using redis). Currently we log a WARNING on the master process every time we receive an echoed RDATA. --- changelog.d/7364.misc | 1 + docs/tcp_replication.md | 41 +++++++++++++++++++------------- synapse/app/_base.py | 4 ++-- synapse/logging/opentracing.py | 23 ++++++++---------- synapse/replication/tcp/commands.py | 37 +++++++++++++++++++--------- synapse/replication/tcp/handler.py | 17 ++++++++++--- synapse/server.py | 13 ++++++++-- synapse/server.pyi | 2 ++ tests/replication/slave/storage/_base.py | 1 + tests/replication/tcp/test_commands.py | 6 +++-- 10 files changed, 95 insertions(+), 50 deletions(-) create mode 100644 changelog.d/7364.misc (limited to 'synapse/server.pyi') diff --git a/changelog.d/7364.misc b/changelog.d/7364.misc new file mode 100644 index 0000000000..bb5d727cf4 --- /dev/null +++ b/changelog.d/7364.misc @@ -0,0 +1 @@ +Add an `instance_name` to `RDATA` and `POSITION` replication commands. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index b922d9cf7e..ab2fffbfe4 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -15,15 +15,17 @@ example flow would be (where '>' indicates master to worker and > SERVER example.com < REPLICATE - > POSITION events 53 - > RDATA events 54 ["$foo1:bar.com", ...] - > RDATA events 55 ["$foo4:bar.com", ...] + > POSITION events master 53 + > RDATA events master 54 ["$foo1:bar.com", ...] + > RDATA events master 55 ["$foo4:bar.com", ...] The example shows the server accepting a new connection and sending its identity with the `SERVER` command, followed by the client server to respond with the position of all streams. The server then periodically sends `RDATA` commands -which have the format `RDATA `, where the format of -`` is defined by the individual streams. +which have the format `RDATA `, where +the format of `` is defined by the individual streams. The +`` is the name of the Synapse process that generated the data +(usually "master"). Error reporting happens by either the client or server sending an ERROR command, and usually the connection will be closed. @@ -52,7 +54,7 @@ The basic structure of the protocol is line based, where the initial word of each line specifies the command. The rest of the line is parsed based on the command. For example, the RDATA command is defined as: - RDATA + RDATA (Note that may contains spaces, but cannot contain newlines.) @@ -136,11 +138,11 @@ the wire: < NAME synapse.app.appservice < PING 1490197665618 < REPLICATE - > POSITION events 1 - > POSITION backfill 1 - > POSITION caches 1 - > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] - > RDATA events 14 ["$149019767112vOHxz:localhost:8823", + > POSITION events master 1 + > POSITION backfill master 1 + > POSITION caches master 1 + > RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] + > RDATA events master 14 ["$149019767112vOHxz:localhost:8823", "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] < PING 1490197675618 > ERROR server stopping @@ -151,10 +153,10 @@ position without needing to send data with the `RDATA` command. An example of a batched set of `RDATA` is: - > RDATA caches batch ["get_user_by_id",["@test:localhost:8823"],1490197670513] - > RDATA caches batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513] - > RDATA caches batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513] - > RDATA caches 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513] + > RDATA caches master 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513] In this case the client shouldn't advance their caches token until it sees the the last `RDATA`. @@ -178,6 +180,11 @@ client (C): updates, and if so then fetch them out of band. Sent in response to a REPLICATE command (but can happen at any time). + The POSITION command includes the source of the stream. Currently all streams + are written by a single process (usually "master"). If fetching missing + updates via HTTP API, rather than via the DB, then processes should make the + request to the appropriate process. + #### ERROR (S, C) There was an error @@ -234,12 +241,12 @@ Each individual cache invalidation results in a row being sent down replication, which includes the cache name (the name of the function) and they key to invalidate. For example: - > RDATA caches 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251] + > RDATA caches master 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251] Alternatively, an entire cache can be invalidated by sending down a `null` instead of the key. For example: - > RDATA caches 550953772 ["get_user_by_id", null, 1550574873252] + > RDATA caches master 550953772 ["get_user_by_id", null, 1550574873252] However, there are times when a number of caches need to be invalidated at the same time with the same key. To reduce traffic we batch those diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 4d84f4595a..628292b890 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -270,7 +270,7 @@ def start(hs, listeners=None): # Start the tracer synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa - hs.config + hs ) # It is now safe to start your Synapse. @@ -316,7 +316,7 @@ def setup_sentry(hs): scope.set_tag("matrix_server_name", hs.config.server_name) app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver" - name = hs.config.worker_name if hs.config.worker_name else "master" + name = hs.get_instance_name() scope.set_tag("worker_app", app) scope.set_tag("worker_name", name) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 0638cec429..5dddf57008 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -171,7 +171,7 @@ import logging import re import types from functools import wraps -from typing import Dict +from typing import TYPE_CHECKING, Dict from canonicaljson import json @@ -179,6 +179,9 @@ from twisted.internet import defer from synapse.config import ConfigError +if TYPE_CHECKING: + from synapse.server import HomeServer + # Helper class @@ -297,14 +300,11 @@ def _noop_context_manager(*args, **kwargs): # Setup -def init_tracer(config): +def init_tracer(hs: "HomeServer"): """Set the whitelists and initialise the JaegerClient tracer - - Args: - config (HomeserverConfig): The config used by the homeserver """ global opentracing - if not config.opentracer_enabled: + if not hs.config.opentracer_enabled: # We don't have a tracer opentracing = None return @@ -315,18 +315,15 @@ def init_tracer(config): "installed." ) - # Include the worker name - name = config.worker_name if config.worker_name else "master" - # Pull out the jaeger config if it was given. Otherwise set it to something sensible. # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py - set_homeserver_whitelist(config.opentracer_whitelist) + set_homeserver_whitelist(hs.config.opentracer_whitelist) JaegerConfig( - config=config.jaeger_config, - service_name="{} {}".format(config.server_name, name), - scope_manager=LogContextScopeManager(config), + config=hs.config.jaeger_config, + service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()), + scope_manager=LogContextScopeManager(hs.config), ).initialize_tracer() diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index c7880d4b63..f58e384d17 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -95,7 +95,7 @@ class RdataCommand(Command): Format:: - RDATA + RDATA The `` may either be a numeric stream id OR "batch". The latter case is used to support sending multiple updates with the same stream ID. This @@ -105,33 +105,40 @@ class RdataCommand(Command): The client should batch all incoming RDATA with a token of "batch" (per stream_name) until it sees an RDATA with a numeric stream ID. + The `` is the source of the new data (usually "master"). + `` of "batch" maps to the instance variable `token` being None. An example of a batched series of RDATA:: - RDATA presence batch ["@foo:example.com", "online", ...] - RDATA presence batch ["@bar:example.com", "online", ...] - RDATA presence 59 ["@baz:example.com", "online", ...] + RDATA presence master batch ["@foo:example.com", "online", ...] + RDATA presence master batch ["@bar:example.com", "online", ...] + RDATA presence master 59 ["@baz:example.com", "online", ...] """ NAME = "RDATA" - def __init__(self, stream_name, token, row): + def __init__(self, stream_name, instance_name, token, row): self.stream_name = stream_name + self.instance_name = instance_name self.token = token self.row = row @classmethod def from_line(cls, line): - stream_name, token, row_json = line.split(" ", 2) + stream_name, instance_name, token, row_json = line.split(" ", 3) return cls( - stream_name, None if token == "batch" else int(token), json.loads(row_json) + stream_name, + instance_name, + None if token == "batch" else int(token), + json.loads(row_json), ) def to_line(self): return " ".join( ( self.stream_name, + self.instance_name, str(self.token) if self.token is not None else "batch", _json_encoder.encode(self.row), ) @@ -145,23 +152,31 @@ class PositionCommand(Command): """Sent by the server to tell the client the stream postition without needing to send an RDATA. + Format:: + + POSITION + On receipt of a POSITION command clients should check if they have missed any updates, and if so then fetch them out of band. + + The `` is the process that sent the command and is the source + of the stream. """ NAME = "POSITION" - def __init__(self, stream_name, token): + def __init__(self, stream_name, instance_name, token): self.stream_name = stream_name + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - return cls(stream_name, int(token)) + stream_name, instance_name, token = line.split(" ", 2) + return cls(stream_name, instance_name, int(token)) def to_line(self): - return " ".join((self.stream_name, str(self.token))) + return " ".join((self.stream_name, self.instance_name, str(self.token))) class ErrorCommand(_SimpleCommand): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b8f49a8d0f..6f7054d5af 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -79,6 +79,7 @@ class ReplicationCommandHandler: self._notifier = hs.get_notifier() self._clock = hs.get_clock() self._instance_id = hs.get_instance_id() + self._instance_name = hs.get_instance_name() # Set of streams that we've caught up with. self._streams_connected = set() # type: Set[str] @@ -156,7 +157,7 @@ class ReplicationCommandHandler: hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, ) else: - client_name = hs.config.worker_name + client_name = hs.get_instance_name() self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port @@ -170,7 +171,9 @@ class ReplicationCommandHandler: for stream_name, stream in self._streams.items(): current_token = stream.current_token() - self.send_command(PositionCommand(stream_name, current_token)) + self.send_command( + PositionCommand(stream_name, self._instance_name, current_token) + ) async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): user_sync_counter.inc() @@ -235,6 +238,10 @@ class ReplicationCommandHandler: await self._server_notices_sender.on_user_ip(cmd.user_id) async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + if cmd.instance_name == self._instance_name: + # Ignore RDATA that are just our own echoes + return + stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -286,6 +293,10 @@ class ReplicationCommandHandler: await self._replication_data_handler.on_rdata(stream_name, token, rows) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + if cmd.instance_name == self._instance_name: + # Ignore POSITION that are just our own echoes + return + stream = self._streams.get(cmd.stream_name) if not stream: logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) @@ -485,7 +496,7 @@ class ReplicationCommandHandler: We need to check if the client is interested in the stream or not """ - self.send_command(RdataCommand(stream_name, token, data)) + self.send_command(RdataCommand(stream_name, self._instance_name, token, data)) UpdateToken = TypeVar("UpdateToken") diff --git a/synapse/server.py b/synapse/server.py index 9d273c980c..bf97a16c09 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -234,7 +234,8 @@ class HomeServer(object): self._listening_services = [] self.start_time = None - self.instance_id = random_string(5) + self._instance_id = random_string(5) + self._instance_name = config.worker_name or "master" self.clock = Clock(reactor) self.distributor = Distributor() @@ -254,7 +255,15 @@ class HomeServer(object): This is used to distinguish running instances in worker-based deployments. """ - return self.instance_id + return self._instance_id + + def get_instance_name(self) -> str: + """A unique name for this synapse process. + + Used to identify the process over replication and in config. Does not + change over restarts. + """ + return self._instance_name def setup(self): logger.info("Setting up.") diff --git a/synapse/server.pyi b/synapse/server.pyi index fc5886f762..18043a2593 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -122,6 +122,8 @@ class HomeServer(object): pass def get_instance_id(self) -> str: pass + def get_instance_name(self) -> str: + pass def get_event_builder_factory(self) -> EventBuilderFactory: pass def get_storage(self) -> synapse.storage.Storage: diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 395c7d0306..1615dfab5e 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -57,6 +57,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): # We now do some gut wrenching so that we have a client that is based # off of the slave store rather than the main store. self.replication_handler = ReplicationCommandHandler(self.hs) + self.replication_handler._instance_name = "worker" self.replication_handler._replication_data_handler = ReplicationDataHandler( self.slaved_store ) diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py index 3cbcb513cc..7ddfd0a733 100644 --- a/tests/replication/tcp/test_commands.py +++ b/tests/replication/tcp/test_commands.py @@ -28,15 +28,17 @@ class ParseCommandTestCase(TestCase): self.assertIsInstance(cmd, ReplicateCommand) def test_parse_rdata(self): - line = 'RDATA events 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' + line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' cmd = parse_command_from_line(line) self.assertIsInstance(cmd, RdataCommand) self.assertEqual(cmd.stream_name, "events") + self.assertEqual(cmd.instance_name, "master") self.assertEqual(cmd.token, 6287863) def test_parse_rdata_batch(self): - line = 'RDATA presence batch ["@foo:example.com", "online"]' + line = 'RDATA presence master batch ["@foo:example.com", "online"]' cmd = parse_command_from_line(line) self.assertIsInstance(cmd, RdataCommand) self.assertEqual(cmd.stream_name, "presence") + self.assertEqual(cmd.instance_name, "master") self.assertIsNone(cmd.token) -- cgit 1.4.1 From 616af44137c78d481024da83bb51ed0d50a49522 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 8 May 2020 14:30:40 +0200 Subject: Implement OpenID Connect-based login (#7256) --- changelog.d/7256.feature | 1 + docs/dev/oidc.md | 175 ++++++ docs/sample_config.yaml | 95 ++++ mypy.ini | 3 + synapse/app/homeserver.py | 12 + synapse/config/_base.pyi | 2 + synapse/config/homeserver.py | 2 + synapse/config/oidc_config.py | 177 ++++++ synapse/config/sso.py | 17 +- synapse/handlers/auth.py | 4 +- synapse/handlers/oidc_handler.py | 998 +++++++++++++++++++++++++++++++++ synapse/http/client.py | 7 + synapse/python_dependencies.py | 1 + synapse/res/templates/sso_error.html | 18 + synapse/rest/client/v1/login.py | 28 +- synapse/rest/oidc/__init__.py | 27 + synapse/rest/oidc/callback_resource.py | 31 + synapse/server.py | 6 + synapse/server.pyi | 5 + tests/handlers/test_oidc.py | 565 +++++++++++++++++++ tox.ini | 1 + 21 files changed, 2163 insertions(+), 12 deletions(-) create mode 100644 changelog.d/7256.feature create mode 100644 docs/dev/oidc.md create mode 100644 synapse/config/oidc_config.py create mode 100644 synapse/handlers/oidc_handler.py create mode 100644 synapse/res/templates/sso_error.html create mode 100644 synapse/rest/oidc/__init__.py create mode 100644 synapse/rest/oidc/callback_resource.py create mode 100644 tests/handlers/test_oidc.py (limited to 'synapse/server.pyi') diff --git a/changelog.d/7256.feature b/changelog.d/7256.feature new file mode 100644 index 0000000000..7ad767bf71 --- /dev/null +++ b/changelog.d/7256.feature @@ -0,0 +1 @@ +Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs). diff --git a/docs/dev/oidc.md b/docs/dev/oidc.md new file mode 100644 index 0000000000..a90c5d2441 --- /dev/null +++ b/docs/dev/oidc.md @@ -0,0 +1,175 @@ +# How to test OpenID Connect + +Any OpenID Connect Provider (OP) should work with Synapse, as long as it supports the authorization code flow. +There are a few options for that: + + - start a local OP. Synapse has been tested with [Hydra][hydra] and [Dex][dex-idp]. + Note that for an OP to work, it should be served under a secure (HTTPS) origin. + A certificate signed with a self-signed, locally trusted CA should work. In that case, start Synapse with a `SSL_CERT_FILE` environment variable set to the path of the CA. + - use a publicly available OP. Synapse has been tested with [Google][google-idp]. + - setup a SaaS OP, like [Auth0][auth0] and [Okta][okta]. Auth0 has a free tier which has been tested with Synapse. + +[google-idp]: https://developers.google.com/identity/protocols/OpenIDConnect#authenticatingtheuser +[auth0]: https://auth0.com/ +[okta]: https://www.okta.com/ +[dex-idp]: https://github.com/dexidp/dex +[hydra]: https://www.ory.sh/docs/hydra/ + + +## Sample configs + +Here are a few configs for providers that should work with Synapse. + +### [Dex][dex-idp] + +[Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider. +Although it is designed to help building a full-blown provider, with some external database, it can be configured with static passwords in a config file. + +Follow the [Getting Started guide](https://github.com/dexidp/dex/blob/master/Documentation/getting-started.md) to install Dex. + +Edit `examples/config-dev.yaml` config file from the Dex repo to add a client: + +```yaml +staticClients: +- id: synapse + secret: secret + redirectURIs: + - '[synapse base url]/_synapse/oidc/callback' + name: 'Synapse' +``` + +Run with `dex serve examples/config-dex.yaml` + +Synapse config: + +```yaml +oidc_config: + enabled: true + skip_verification: true # This is needed as Dex is served on an insecure endpoint + issuer: "http://127.0.0.1:5556/dex" + discover: true + client_id: "synapse" + client_secret: "secret" + scopes: + - openid + - profile + user_mapping_provider: + config: + localpart_template: '{{ user.name }}' + display_name_template: '{{ user.name|capitalize }}' +``` + +### [Auth0][auth0] + +1. Create a regular web application for Synapse +2. Set the Allowed Callback URLs to `[synapse base url]/_synapse/oidc/callback` +3. Add a rule to add the `preferred_username` claim. +
+ Code sample + + ```js + function addPersistenceAttribute(user, context, callback) { + user.user_metadata = user.user_metadata || {}; + user.user_metadata.preferred_username = user.user_metadata.preferred_username || user.user_id; + context.idToken.preferred_username = user.user_metadata.preferred_username; + + auth0.users.updateUserMetadata(user.user_id, user.user_metadata) + .then(function(){ + callback(null, user, context); + }) + .catch(function(err){ + callback(err); + }); + } + ``` + +
+ + +```yaml +oidc_config: + enabled: true + issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED + discover: true + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: + - openid + - profile + user_mapping_provider: + config: + localpart_template: '{{ user.preferred_username }}' + display_name_template: '{{ user.name }}' +``` + +### GitHub + +GitHub is a bit special as it is not an OpenID Connect compliant provider, but just a regular OAuth2 provider. +The `/user` API endpoint can be used to retrieve informations from the user. +As the OIDC login mechanism needs an attribute to uniquely identify users and that endpoint does not return a `sub` property, an alternative `subject_claim` has to be set. + +1. Create a new OAuth application: https://github.com/settings/applications/new +2. Set the callback URL to `[synapse base url]/_synapse/oidc/callback` + +```yaml +oidc_config: + enabled: true + issuer: "https://github.com/" + discover: false + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + authorization_endpoint: "https://github.com/login/oauth/authorize" + token_endpoint: "https://github.com/login/oauth/access_token" + userinfo_endpoint: "https://api.github.com/user" + scopes: + - read:user + user_mapping_provider: + config: + subject_claim: 'id' + localpart_template: '{{ user.login }}' + display_name_template: '{{ user.name }}' +``` + +### Google + +1. Setup a project in the Google API Console +2. Obtain the OAuth 2.0 credentials (see ) +3. Add this Authorized redirect URI: `[synapse base url]/_synapse/oidc/callback` + +```yaml +oidc_config: + enabled: true + issuer: "https://accounts.google.com/" + discover: true + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: + - openid + - profile + user_mapping_provider: + config: + localpart_template: '{{ user.given_name|lower }}' + display_name_template: '{{ user.name }}' +``` + +### Twitch + +1. Setup a developer account on [Twitch](https://dev.twitch.tv/) +2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/) +3. Add this OAuth Redirect URL: `[synapse base url]/_synapse/oidc/callback` + +```yaml +oidc_config: + enabled: true + issuer: "https://id.twitch.tv/oauth2/" + discover: true + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + scopes: + - openid + user_mapping_provider: + config: + localpart_template: '{{ user.preferred_username }}' + display_name_template: '{{ user.name }}' +``` diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 98ead7dc0e..1e397f7734 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1470,6 +1470,94 @@ saml2_config: #template_dir: "res/templates" +# Enable OpenID Connect for registration and login. Uses authlib. +# +oidc_config: + # enable OpenID Connect. Defaults to false. + # + #enabled: true + + # use the OIDC discovery mechanism to discover endpoints. Defaults to true. + # + #discover: true + + # the OIDC issuer. Used to validate tokens and discover the providers endpoints. Required. + # + #issuer: "https://accounts.example.com/" + + # oauth2 client id to use. Required. + # + #client_id: "provided-by-your-issuer" + + # oauth2 client secret to use. Required. + # + #client_secret: "provided-by-your-issuer" + + # auth method to use when exchanging the token. + # Valid values are "client_secret_basic" (default), "client_secret_post" and "none". + # + #client_auth_method: "client_auth_basic" + + # list of scopes to ask. This should include the "openid" scope. Defaults to ["openid"]. + # + #scopes: ["openid"] + + # the oauth2 authorization endpoint. Required if provider discovery is disabled. + # + #authorization_endpoint: "https://accounts.example.com/oauth2/auth" + + # the oauth2 token endpoint. Required if provider discovery is disabled. + # + #token_endpoint: "https://accounts.example.com/oauth2/token" + + # the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not asked. + # + #userinfo_endpoint: "https://accounts.example.com/userinfo" + + # URI where to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used. + # + #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + + # skip metadata verification. Defaults to false. + # Use this if you are connecting to a provider that is not OpenID Connect compliant. + # Avoid this in production. + # + #skip_verification: false + + + # An external module can be provided here as a custom solution to mapping + # attributes returned from a OIDC provider onto a matrix user. + # + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. + # + #module: mapping_provider.OidcMappingProvider + + # Custom configuration values for the module. Below options are intended + # for the built-in provider, they should be changed if using a custom + # module. This section will be passed as a Python dictionary to the + # module's `parse_config` method. + # + # Below is the config of the default mapping provider, based on Jinja2 + # templates. Those templates are used to render user attributes, where the + # userinfo object is available through the `user` variable. + # + config: + # name of the claim containing a unique identifier for the user. + # Defaults to `sub`, which OpenID Connect compliant providers should provide. + # + #subject_claim: "sub" + + # Jinja2 template for the localpart of the MXID + # + localpart_template: "{{ user.preferred_username }}" + + # Jinja2 template for the display name to set on first login. Optional. + # + #display_name_template: "{{ user.given_name }} {{ user.last_name }}" + + # Enable CAS for registration and login. # @@ -1554,6 +1642,13 @@ sso: # # This template has no additional variables. # + # * HTML page to display to users if something goes wrong during the + # OpenID Connect authentication process: 'sso_error.html'. + # + # When rendering, this template is given two variables: + # * error: the technical name of the error + # * error_description: a human-readable message for the error + # # You can see the default templates at: # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates # diff --git a/mypy.ini b/mypy.ini index 69be2f67ad..3533797d68 100644 --- a/mypy.ini +++ b/mypy.ini @@ -75,3 +75,6 @@ ignore_missing_imports = True [mypy-jwt.*] ignore_missing_imports = True + +[mypy-authlib.*] +ignore_missing_imports = True diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index cbd1ea475a..bc8695d8dd 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -192,6 +192,11 @@ class SynapseHomeServer(HomeServer): } ) + if self.get_config().oidc_enabled: + from synapse.rest.oidc import OIDCResource + + resources["/_synapse/oidc"] = OIDCResource(self) + if self.get_config().saml2_enabled: from synapse.rest.saml2 import SAML2Resource @@ -422,6 +427,13 @@ def setup(config_options): # Check if it needs to be reprovisioned every day. hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) + # Load the OIDC provider metadatas, if OIDC is enabled. + if hs.config.oidc_enabled: + oidc = hs.get_oidc_handler() + # Loading the provider metadata also ensures the provider config is valid. + yield defer.ensureDeferred(oidc.load_metadata()) + yield defer.ensureDeferred(oidc.load_jwks()) + _base.start(hs, config.listeners) hs.get_datastore().db.updates.start_doing_background_updates() diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 3053fc9d27..9e576060d4 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -13,6 +13,7 @@ from synapse.config import ( key, logger, metrics, + oidc_config, password, password_auth_providers, push, @@ -59,6 +60,7 @@ class RootConfig: saml2: saml2_config.SAML2Config cas: cas.CasConfig sso: sso.SSOConfig + oidc: oidc_config.OIDCConfig jwt: jwt_config.JWTConfig password: password.PasswordConfig email: emailconfig.EmailConfig diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index be6c6afa74..996d3e6bf7 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -27,6 +27,7 @@ from .jwt_config import JWTConfig from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig +from .oidc_config import OIDCConfig from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig @@ -66,6 +67,7 @@ class HomeServerConfig(RootConfig): AppServiceConfig, KeyConfig, SAML2Config, + OIDCConfig, CasConfig, SSOConfig, JWTConfig, diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py new file mode 100644 index 0000000000..5af110745e --- /dev/null +++ b/synapse/config/oidc_config.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.module_loader import load_module + +from ._base import Config, ConfigError + +DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider" + + +class OIDCConfig(Config): + section = "oidc" + + def read_config(self, config, **kwargs): + self.oidc_enabled = False + + oidc_config = config.get("oidc_config") + + if not oidc_config or not oidc_config.get("enabled", False): + return + + try: + check_requirements("oidc") + except DependencyException as e: + raise ConfigError(e.message) + + public_baseurl = self.public_baseurl + if public_baseurl is None: + raise ConfigError("oidc_config requires a public_baseurl to be set") + self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" + + self.oidc_enabled = True + self.oidc_discover = oidc_config.get("discover", True) + self.oidc_issuer = oidc_config["issuer"] + self.oidc_client_id = oidc_config["client_id"] + self.oidc_client_secret = oidc_config["client_secret"] + self.oidc_client_auth_method = oidc_config.get( + "client_auth_method", "client_secret_basic" + ) + self.oidc_scopes = oidc_config.get("scopes", ["openid"]) + self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint") + self.oidc_token_endpoint = oidc_config.get("token_endpoint") + self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") + self.oidc_jwks_uri = oidc_config.get("jwks_uri") + self.oidc_subject_claim = oidc_config.get("subject_claim", "sub") + self.oidc_skip_verification = oidc_config.get("skip_verification", False) + + ump_config = oidc_config.get("user_mapping_provider", {}) + ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + ump_config.setdefault("config", {}) + + ( + self.oidc_user_mapping_provider_class, + self.oidc_user_mapping_provider_config, + ) = load_module(ump_config) + + # Ensure loaded user mapping module has defined all necessary methods + required_methods = [ + "get_remote_user_id", + "map_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(self.oidc_user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class specified by oidc_config." + "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + # Enable OpenID Connect for registration and login. Uses authlib. + # + oidc_config: + # enable OpenID Connect. Defaults to false. + # + #enabled: true + + # use the OIDC discovery mechanism to discover endpoints. Defaults to true. + # + #discover: true + + # the OIDC issuer. Used to validate tokens and discover the providers endpoints. Required. + # + #issuer: "https://accounts.example.com/" + + # oauth2 client id to use. Required. + # + #client_id: "provided-by-your-issuer" + + # oauth2 client secret to use. Required. + # + #client_secret: "provided-by-your-issuer" + + # auth method to use when exchanging the token. + # Valid values are "client_secret_basic" (default), "client_secret_post" and "none". + # + #client_auth_method: "client_auth_basic" + + # list of scopes to ask. This should include the "openid" scope. Defaults to ["openid"]. + # + #scopes: ["openid"] + + # the oauth2 authorization endpoint. Required if provider discovery is disabled. + # + #authorization_endpoint: "https://accounts.example.com/oauth2/auth" + + # the oauth2 token endpoint. Required if provider discovery is disabled. + # + #token_endpoint: "https://accounts.example.com/oauth2/token" + + # the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not asked. + # + #userinfo_endpoint: "https://accounts.example.com/userinfo" + + # URI where to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used. + # + #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + + # skip metadata verification. Defaults to false. + # Use this if you are connecting to a provider that is not OpenID Connect compliant. + # Avoid this in production. + # + #skip_verification: false + + + # An external module can be provided here as a custom solution to mapping + # attributes returned from a OIDC provider onto a matrix user. + # + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # Default is {mapping_provider!r}. + # + #module: mapping_provider.OidcMappingProvider + + # Custom configuration values for the module. Below options are intended + # for the built-in provider, they should be changed if using a custom + # module. This section will be passed as a Python dictionary to the + # module's `parse_config` method. + # + # Below is the config of the default mapping provider, based on Jinja2 + # templates. Those templates are used to render user attributes, where the + # userinfo object is available through the `user` variable. + # + config: + # name of the claim containing a unique identifier for the user. + # Defaults to `sub`, which OpenID Connect compliant providers should provide. + # + #subject_claim: "sub" + + # Jinja2 template for the localpart of the MXID + # + localpart_template: "{{{{ user.preferred_username }}}}" + + # Jinja2 template for the display name to set on first login. Optional. + # + #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" + """.format( + mapping_provider=DEFAULT_USER_MAPPING_PROVIDER + ) diff --git a/synapse/config/sso.py b/synapse/config/sso.py index cac6bc0139..aff642f015 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -36,17 +36,13 @@ class SSOConfig(Config): if not template_dir: template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - self.sso_redirect_confirm_template_dir = template_dir + self.sso_template_dir = template_dir self.sso_account_deactivated_template = self.read_file( - os.path.join( - self.sso_redirect_confirm_template_dir, "sso_account_deactivated.html" - ), + os.path.join(self.sso_template_dir, "sso_account_deactivated.html"), "sso_account_deactivated_template", ) self.sso_auth_success_template = self.read_file( - os.path.join( - self.sso_redirect_confirm_template_dir, "sso_auth_success.html" - ), + os.path.join(self.sso_template_dir, "sso_auth_success.html"), "sso_auth_success_template", ) @@ -137,6 +133,13 @@ class SSOConfig(Config): # # This template has no additional variables. # + # * HTML page to display to users if something goes wrong during the + # OpenID Connect authentication process: 'sso_error.html'. + # + # When rendering, this template is given two variables: + # * error: the technical name of the error + # * error_description: a human-readable message for the error + # # You can see the default templates at: # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates # diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7613e5b6ab..f8d2331bf1 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -126,13 +126,13 @@ class AuthHandler(BaseHandler): # It notifies the user they are about to give access to their matrix account # to the client. self._sso_redirect_confirm_template = load_jinja2_templates( - hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], + hs.config.sso_template_dir, ["sso_redirect_confirm.html"], )[0] # The following template is shown during user interactive authentication # in the fallback auth scenario. It notifies the user that they are # authenticating for an operation to occur on their account. self._sso_auth_confirm_template = load_jinja2_templates( - hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"], + hs.config.sso_template_dir, ["sso_auth_confirm.html"], )[0] # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py new file mode 100644 index 0000000000..178f263439 --- /dev/null +++ b/synapse/handlers/oidc_handler.py @@ -0,0 +1,998 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +from typing import Dict, Generic, List, Optional, Tuple, TypeVar +from urllib.parse import urlencode + +import attr +import pymacaroons +from authlib.common.security import generate_token +from authlib.jose import JsonWebToken +from authlib.oauth2.auth import ClientAuth +from authlib.oauth2.rfc6749.parameters import prepare_grant_uri +from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo +from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url +from jinja2 import Environment, Template +from pymacaroons.exceptions import ( + MacaroonDeserializationException, + MacaroonInvalidSignatureException, +) +from typing_extensions import TypedDict + +from twisted.web.client import readBody + +from synapse.config import ConfigError +from synapse.http.server import finish_request +from synapse.http.site import SynapseRequest +from synapse.push.mailer import load_jinja2_templates +from synapse.server import HomeServer +from synapse.types import UserID, map_username_to_mxid_localpart + +logger = logging.getLogger(__name__) + +SESSION_COOKIE_NAME = b"oidc_session" + +#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and +#: OpenID.Core sec 3.1.3.3. +Token = TypedDict( + "Token", + { + "access_token": str, + "token_type": str, + "id_token": Optional[str], + "refresh_token": Optional[str], + "expires_in": int, + "scope": Optional[str], + }, +) + +#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but +#: there is no real point of doing this in our case. +JWK = Dict[str, str] + +#: A JWK Set, as per RFC7517 sec 5. +JWKS = TypedDict("JWKS", {"keys": List[JWK]}) + + +class OidcError(Exception): + """Used to catch errors when calling the token_endpoint + """ + + def __init__(self, error, error_description=None): + self.error = error + self.error_description = error_description + + def __str__(self): + if self.error_description: + return "{}: {}".format(self.error, self.error_description) + return self.error + + +class MappingException(Exception): + """Used to catch errors when mapping the UserInfo object + """ + + +class OidcHandler: + """Handles requests related to the OpenID Connect login flow. + """ + + def __init__(self, hs: HomeServer): + self._callback_url = hs.config.oidc_callback_url # type: str + self._scopes = hs.config.oidc_scopes # type: List[str] + self._client_auth = ClientAuth( + hs.config.oidc_client_id, + hs.config.oidc_client_secret, + hs.config.oidc_client_auth_method, + ) # type: ClientAuth + self._client_auth_method = hs.config.oidc_client_auth_method # type: str + self._subject_claim = hs.config.oidc_subject_claim + self._provider_metadata = OpenIDProviderMetadata( + issuer=hs.config.oidc_issuer, + authorization_endpoint=hs.config.oidc_authorization_endpoint, + token_endpoint=hs.config.oidc_token_endpoint, + userinfo_endpoint=hs.config.oidc_userinfo_endpoint, + jwks_uri=hs.config.oidc_jwks_uri, + ) # type: OpenIDProviderMetadata + self._provider_needs_discovery = hs.config.oidc_discover # type: bool + self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class( + hs.config.oidc_user_mapping_provider_config + ) # type: OidcMappingProvider + self._skip_verification = hs.config.oidc_skip_verification # type: bool + + self._http_client = hs.get_proxied_http_client() + self._auth_handler = hs.get_auth_handler() + self._registration_handler = hs.get_registration_handler() + self._datastore = hs.get_datastore() + self._clock = hs.get_clock() + self._hostname = hs.hostname # type: str + self._server_name = hs.config.server_name # type: str + self._macaroon_secret_key = hs.config.macaroon_secret_key + self._error_template = load_jinja2_templates( + hs.config.sso_template_dir, ["sso_error.html"] + )[0] + + # identifier for the external_ids table + self._auth_provider_id = "oidc" + + def _render_error( + self, request, error: str, error_description: Optional[str] = None + ) -> None: + """Renders the error template and respond with it. + + This is used to show errors to the user. The template of this page can + be found under ``synapse/res/templates/sso_error.html``. + + Args: + request: The incoming request from the browser. + We'll respond with an HTML page describing the error. + error: A technical identifier for this error. Those include + well-known OAuth2/OIDC error types like invalid_request or + access_denied. + error_description: A human-readable description of the error. + """ + html_bytes = self._error_template.render( + error=error, error_description=error_description + ).encode("utf-8") + + request.setResponseCode(400) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%i" % len(html_bytes)) + request.write(html_bytes) + finish_request(request) + + def _validate_metadata(self): + """Verifies the provider metadata. + + This checks the validity of the currently loaded provider. Not + everything is checked, only: + + - ``issuer`` + - ``authorization_endpoint`` + - ``token_endpoint`` + - ``response_types_supported`` (checks if "code" is in it) + - ``jwks_uri`` + + Raises: + ValueError: if something in the provider is not valid + """ + # Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin) + if self._skip_verification is True: + return + + m = self._provider_metadata + m.validate_issuer() + m.validate_authorization_endpoint() + m.validate_token_endpoint() + + if m.get("token_endpoint_auth_methods_supported") is not None: + m.validate_token_endpoint_auth_methods_supported() + if ( + self._client_auth_method + not in m["token_endpoint_auth_methods_supported"] + ): + raise ValueError( + '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format( + auth_method=self._client_auth_method, + supported=m["token_endpoint_auth_methods_supported"], + ) + ) + + if m.get("response_types_supported") is not None: + m.validate_response_types_supported() + + if "code" not in m["response_types_supported"]: + raise ValueError( + '"code" not in "response_types_supported" (%r)' + % (m["response_types_supported"],) + ) + + # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos + if self._uses_userinfo: + if m.get("userinfo_endpoint") is None: + raise ValueError( + 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested' + ) + else: + # If we're not using userinfo, we need a valid jwks to validate the ID token + if m.get("jwks") is None: + if m.get("jwks_uri") is not None: + m.validate_jwks_uri() + else: + raise ValueError('"jwks_uri" must be set') + + @property + def _uses_userinfo(self) -> bool: + """Returns True if the ``userinfo_endpoint`` should be used. + + This is based on the requested scopes: if the scopes include + ``openid``, the provider should give use an ID token containing the + user informations. If not, we should fetch them using the + ``access_token`` with the ``userinfo_endpoint``. + """ + + # Maybe that should be user-configurable and not inferred? + return "openid" not in self._scopes + + async def load_metadata(self) -> OpenIDProviderMetadata: + """Load and validate the provider metadata. + + The values metadatas are discovered if ``oidc_config.discovery`` is + ``True`` and then cached. + + Raises: + ValueError: if something in the provider is not valid + + Returns: + The provider's metadata. + """ + # If we are using the OpenID Discovery documents, it needs to be loaded once + # FIXME: should there be a lock here? + if self._provider_needs_discovery: + url = get_well_known_url(self._provider_metadata["issuer"], external=True) + metadata_response = await self._http_client.get_json(url) + # TODO: maybe update the other way around to let user override some values? + self._provider_metadata.update(metadata_response) + self._provider_needs_discovery = False + + self._validate_metadata() + + return self._provider_metadata + + async def load_jwks(self, force: bool = False) -> JWKS: + """Load the JSON Web Key Set used to sign ID tokens. + + If we're not using the ``userinfo_endpoint``, user infos are extracted + from the ID token, which is a JWT signed by keys given by the provider. + The keys are then cached. + + Args: + force: Force reloading the keys. + + Returns: + The key set + + Looks like this:: + + { + 'keys': [ + { + 'kid': 'abcdef', + 'kty': 'RSA', + 'alg': 'RS256', + 'use': 'sig', + 'e': 'XXXX', + 'n': 'XXXX', + } + ] + } + """ + if self._uses_userinfo: + # We're not using jwt signing, return an empty jwk set + return {"keys": []} + + # First check if the JWKS are loaded in the provider metadata. + # It can happen either if the provider gives its JWKS in the discovery + # document directly or if it was already loaded once. + metadata = await self.load_metadata() + jwk_set = metadata.get("jwks") + if jwk_set is not None and not force: + return jwk_set + + # Loading the JWKS using the `jwks_uri` metadata + uri = metadata.get("jwks_uri") + if not uri: + raise RuntimeError('Missing "jwks_uri" in metadata') + + jwk_set = await self._http_client.get_json(uri) + + # Caching the JWKS in the provider's metadata + self._provider_metadata["jwks"] = jwk_set + return jwk_set + + async def _exchange_code(self, code: str) -> Token: + """Exchange an authorization code for a token. + + This calls the ``token_endpoint`` with the authorization code we + received in the callback to exchange it for a token. The call uses the + ``ClientAuth`` to authenticate with the client with its ID and secret. + + Args: + code: The autorization code we got from the callback. + + Returns: + A dict containing various tokens. + + May look like this:: + + { + 'token_type': 'bearer', + 'access_token': 'abcdef', + 'expires_in': 3599, + 'id_token': 'ghijkl', + 'refresh_token': 'mnopqr', + } + + Raises: + OidcError: when the ``token_endpoint`` returned an error. + """ + metadata = await self.load_metadata() + token_endpoint = metadata.get("token_endpoint") + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": self._http_client.user_agent, + "Accept": "application/json", + } + + args = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self._callback_url, + } + body = urlencode(args, True) + + # Fill the body/headers with credentials + uri, headers, body = self._client_auth.prepare( + method="POST", uri=token_endpoint, headers=headers, body=body + ) + headers = {k: [v] for (k, v) in headers.items()} + + # Do the actual request + # We're not using the SimpleHttpClient util methods as we don't want to + # check the HTTP status code and we do the body encoding ourself. + response = await self._http_client.request( + method="POST", uri=uri, data=body.encode("utf-8"), headers=headers, + ) + + # This is used in multiple error messages below + status = "{code} {phrase}".format( + code=response.code, phrase=response.phrase.decode("utf-8") + ) + + resp_body = await readBody(response) + + if response.code >= 500: + # In case of a server error, we should first try to decode the body + # and check for an error field. If not, we respond with a generic + # error message. + try: + resp = json.loads(resp_body.decode("utf-8")) + error = resp["error"] + description = resp.get("error_description", error) + except (ValueError, KeyError): + # Catch ValueError for the JSON decoding and KeyError for the "error" field + error = "server_error" + description = ( + ( + 'Authorization server responded with a "{status}" error ' + "while exchanging the authorization code." + ).format(status=status), + ) + + raise OidcError(error, description) + + # Since it is a not a 5xx code, body should be a valid JSON. It will + # raise if not. + resp = json.loads(resp_body.decode("utf-8")) + + if "error" in resp: + error = resp["error"] + # In case the authorization server responded with an error field, + # it should be a 4xx code. If not, warn about it but don't do + # anything special and report the original error message. + if response.code < 400: + logger.debug( + "Invalid response from the authorization server: " + 'responded with a "{status}" ' + "but body has an error field: {error!r}".format( + status=status, error=resp["error"] + ) + ) + + description = resp.get("error_description", error) + raise OidcError(error, description) + + # Now, this should not be an error. According to RFC6749 sec 5.1, it + # should be a 200 code. We're a bit more flexible than that, and will + # only throw on a 4xx code. + if response.code >= 400: + description = ( + 'Authorization server responded with a "{status}" error ' + 'but did not include an "error" field in its response.'.format( + status=status + ) + ) + logger.warning(description) + # Body was still valid JSON. Might be useful to log it for debugging. + logger.warning("Code exchange response: {resp!r}".format(resp=resp)) + raise OidcError("server_error", description) + + return resp + + async def _fetch_userinfo(self, token: Token) -> UserInfo: + """Fetch user informations from the ``userinfo_endpoint``. + + Args: + token: the token given by the ``token_endpoint``. + Must include an ``access_token`` field. + + Returns: + UserInfo: an object representing the user. + """ + metadata = await self.load_metadata() + + resp = await self._http_client.get_json( + metadata["userinfo_endpoint"], + headers={"Authorization": ["Bearer {}".format(token["access_token"])]}, + ) + + return UserInfo(resp) + + async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: + """Return an instance of UserInfo from token's ``id_token``. + + Args: + token: the token given by the ``token_endpoint``. + Must include an ``id_token`` field. + nonce: the nonce value originally sent in the initial authorization + request. This value should match the one inside the token. + + Returns: + An object representing the user. + """ + metadata = await self.load_metadata() + claims_params = { + "nonce": nonce, + "client_id": self._client_auth.client_id, + } + if "access_token" in token: + # If we got an `access_token`, there should be an `at_hash` claim + # in the `id_token` that we can check against. + claims_params["access_token"] = token["access_token"] + claims_cls = CodeIDToken + else: + claims_cls = ImplicitIDToken + + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + + jwt = JsonWebToken(alg_values) + + claim_options = {"iss": {"values": [metadata["issuer"]]}} + + # Try to decode the keys in cache first, then retry by forcing the keys + # to be reloaded + jwk_set = await self.load_jwks() + try: + claims = jwt.decode( + token["id_token"], + key=jwk_set, + claims_cls=claims_cls, + claims_options=claim_options, + claims_params=claims_params, + ) + except ValueError: + jwk_set = await self.load_jwks(force=True) # try reloading the jwks + claims = jwt.decode( + token["id_token"], + key=jwk_set, + claims_cls=claims_cls, + claims_options=claim_options, + claims_params=claims_params, + ) + + claims.validate(leeway=120) # allows 2 min of clock skew + return UserInfo(claims) + + async def handle_redirect_request( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> None: + """Handle an incoming request to /login/sso/redirect + + It redirects the browser to the authorization endpoint with a few + parameters: + + - ``client_id``: the client ID set in ``oidc_config.client_id`` + - ``response_type``: ``code`` + - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback`` + - ``scope``: the list of scopes set in ``oidc_config.scopes`` + - ``state``: a random string + - ``nonce``: a random string + + In addition to redirecting the client, we are setting a cookie with + a signed macaroon token containing the state, the nonce and the + client_redirect_url params. Those are then checked when the client + comes back from the provider. + + + Args: + request: the incoming request from the browser. + We'll respond to it with a redirect and a cookie. + client_redirect_url: the URL that we should redirect the client to + when everything is done + """ + + state = generate_token() + nonce = generate_token() + + cookie = self._generate_oidc_session_token( + state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(), + ) + request.addCookie( + SESSION_COOKIE_NAME, + cookie, + path="/_synapse/oidc", + max_age="3600", + httpOnly=True, + sameSite="lax", + ) + + metadata = await self.load_metadata() + authorization_endpoint = metadata.get("authorization_endpoint") + uri = prepare_grant_uri( + authorization_endpoint, + client_id=self._client_auth.client_id, + response_type="code", + redirect_uri=self._callback_url, + scope=self._scopes, + state=state, + nonce=nonce, + ) + request.redirect(uri) + finish_request(request) + + async def handle_oidc_callback(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/oidc/callback + + Since we might want to display OIDC-related errors in a user-friendly + way, we don't raise SynapseError from here. Instead, we call + ``self._render_error`` which displays an HTML page for the error. + + Most of the OpenID Connect logic happens here: + + - first, we check if there was any error returned by the provider and + display it + - then we fetch the session cookie, decode and verify it + - the ``state`` query parameter should match with the one stored in the + session cookie + - once we known this session is legit, exchange the code with the + provider using the ``token_endpoint`` (see ``_exchange_code``) + - once we have the token, use it to either extract the UserInfo from + the ``id_token`` (``_parse_id_token``), or use the ``access_token`` + to fetch UserInfo from the ``userinfo_endpoint`` + (``_fetch_userinfo``) + - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and + finish the login + + Args: + request: the incoming request from the browser. + """ + + # The provider might redirect with an error. + # In that case, just display it as-is. + if b"error" in request.args: + error = request.args[b"error"][0].decode() + description = request.args.get(b"error_description", [b""])[0].decode() + + # Most of the errors returned by the provider could be due by + # either the provider misbehaving or Synapse being misconfigured. + # The only exception of that is "access_denied", where the user + # probably cancelled the login flow. In other cases, log those errors. + if error != "access_denied": + logger.error("Error from the OIDC provider: %s %s", error, description) + + self._render_error(request, error, description) + return + + # Fetch the session cookie + session = request.getCookie(SESSION_COOKIE_NAME) + if session is None: + logger.info("No session cookie found") + self._render_error(request, "missing_session", "No session cookie found") + return + + # Remove the cookie. There is a good chance that if the callback failed + # once, it will fail next time and the code will already be exchanged. + # Removing it early avoids spamming the provider with token requests. + request.addCookie( + SESSION_COOKIE_NAME, + b"", + path="/_synapse/oidc", + expires="Thu, Jan 01 1970 00:00:00 UTC", + httpOnly=True, + sameSite="lax", + ) + + # Check for the state query parameter + if b"state" not in request.args: + logger.info("State parameter is missing") + self._render_error(request, "invalid_request", "State parameter is missing") + return + + state = request.args[b"state"][0].decode() + + # Deserialize the session token and verify it. + try: + nonce, client_redirect_url = self._verify_oidc_session_token(session, state) + except MacaroonDeserializationException as e: + logger.exception("Invalid session") + self._render_error(request, "invalid_session", str(e)) + return + except MacaroonInvalidSignatureException as e: + logger.exception("Could not verify session") + self._render_error(request, "mismatching_session", str(e)) + return + + # Exchange the code with the provider + if b"code" not in request.args: + logger.info("Code parameter is missing") + self._render_error(request, "invalid_request", "Code parameter is missing") + return + + logger.info("Exchanging code") + code = request.args[b"code"][0].decode() + try: + token = await self._exchange_code(code) + except OidcError as e: + logger.exception("Could not exchange code") + self._render_error(request, e.error, e.error_description) + return + + # Now that we have a token, get the userinfo, either by decoding the + # `id_token` or by fetching the `userinfo_endpoint`. + if self._uses_userinfo: + logger.info("Fetching userinfo") + try: + userinfo = await self._fetch_userinfo(token) + except Exception as e: + logger.exception("Could not fetch userinfo") + self._render_error(request, "fetch_error", str(e)) + return + else: + logger.info("Extracting userinfo from id_token") + try: + userinfo = await self._parse_id_token(token, nonce=nonce) + except Exception as e: + logger.exception("Invalid id_token") + self._render_error(request, "invalid_token", str(e)) + return + + # Call the mapper to register/login the user + try: + user_id = await self._map_userinfo_to_user(userinfo, token) + except MappingException as e: + logger.exception("Could not map user") + self._render_error(request, "mapping_error", str(e)) + return + + # and finally complete the login + await self._auth_handler.complete_sso_login( + user_id, request, client_redirect_url + ) + + def _generate_oidc_session_token( + self, + state: str, + nonce: str, + client_redirect_url: str, + duration_in_ms: int = (60 * 60 * 1000), + ) -> str: + """Generates a signed token storing data about an OIDC session. + + When Synapse initiates an authorization flow, it creates a random state + and a random nonce. Those parameters are given to the provider and + should be verified when the client comes back from the provider. + It is also used to store the client_redirect_url, which is used to + complete the SSO login flow. + + Args: + state: The ``state`` parameter passed to the OIDC provider. + nonce: The ``nonce`` parameter passed to the OIDC provider. + client_redirect_url: The URL the client gave when it initiated the + flow. + duration_in_ms: An optional duration for the token in milliseconds. + Defaults to an hour. + + Returns: + A signed macaroon token with the session informations. + """ + macaroon = pymacaroons.Macaroon( + location=self._server_name, identifier="key", key=self._macaroon_secret_key, + ) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = session") + macaroon.add_first_party_caveat("state = %s" % (state,)) + macaroon.add_first_party_caveat("nonce = %s" % (nonce,)) + macaroon.add_first_party_caveat( + "client_redirect_url = %s" % (client_redirect_url,) + ) + now = self._clock.time_msec() + expiry = now + duration_in_ms + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + return macaroon.serialize() + + def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]: + """Verifies and extract an OIDC session token. + + This verifies that a given session token was issued by this homeserver + and extract the nonce and client_redirect_url caveats. + + Args: + session: The session token to verify + state: The state the OIDC provider gave back + + Returns: + The nonce and the client_redirect_url for this session + """ + macaroon = pymacaroons.Macaroon.deserialize(session) + + v = pymacaroons.Verifier() + v.satisfy_exact("gen = 1") + v.satisfy_exact("type = session") + v.satisfy_exact("state = %s" % (state,)) + v.satisfy_general(lambda c: c.startswith("nonce = ")) + v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) + v.satisfy_general(self._verify_expiry) + + v.verify(macaroon, self._macaroon_secret_key) + + # Extract the `nonce` and `client_redirect_url` from the token + nonce = self._get_value_from_macaroon(macaroon, "nonce") + client_redirect_url = self._get_value_from_macaroon( + macaroon, "client_redirect_url" + ) + + return nonce, client_redirect_url + + def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: + """Extracts a caveat value from a macaroon token. + + Args: + macaroon: the token + key: the key of the caveat to extract + + Returns: + The extracted value + + Raises: + Exception: if the caveat was not in the macaroon + """ + prefix = key + " = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(prefix): + return caveat.caveat_id[len(prefix) :] + raise Exception("No %s caveat in macaroon" % (key,)) + + def _verify_expiry(self, caveat: str) -> bool: + prefix = "time < " + if not caveat.startswith(prefix): + return False + expiry = int(caveat[len(prefix) :]) + now = self._clock.time_msec() + return now < expiry + + async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: + """Maps a UserInfo object to a mxid. + + UserInfo should have a claim that uniquely identifies users. This claim + is usually `sub`, but can be configured with `oidc_config.subject_claim`. + It is then used as an `external_id`. + + If we don't find the user that way, we should register the user, + mapping the localpart and the display name from the UserInfo. + + If a user already exists with the mxid we've mapped, raise an exception. + + Args: + userinfo: an object representing the user + token: a dict with the tokens obtained from the provider + + Raises: + MappingException: if there was an error while mapping some properties + + Returns: + The mxid of the user + """ + try: + remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo) + except Exception as e: + raise MappingException( + "Failed to extract subject from OIDC response: %s" % (e,) + ) + + logger.info( + "Looking for existing mapping for user %s:%s", + self._auth_provider_id, + remote_user_id, + ) + + registered_user_id = await self._datastore.get_user_by_external_id( + self._auth_provider_id, remote_user_id, + ) + + if registered_user_id is not None: + logger.info("Found existing mapping %s", registered_user_id) + return registered_user_id + + try: + attributes = await self._user_mapping_provider.map_user_attributes( + userinfo, token + ) + except Exception as e: + raise MappingException( + "Could not extract user attributes from OIDC response: " + str(e) + ) + + logger.debug( + "Retrieved user attributes from user mapping provider: %r", attributes + ) + + if not attributes["localpart"]: + raise MappingException("localpart is empty") + + localpart = map_username_to_mxid_localpart(attributes["localpart"]) + + user_id = UserID(localpart, self._hostname) + if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()): + # This mxid is taken + raise MappingException( + "mxid '{}' is already taken".format(user_id.to_string()) + ) + + # It's the first time this user is logging in and the mapped mxid was + # not taken, register the user + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=attributes["display_name"], + ) + + await self._datastore.record_user_external_id( + self._auth_provider_id, remote_user_id, registered_user_id, + ) + return registered_user_id + + +UserAttribute = TypedDict( + "UserAttribute", {"localpart": str, "display_name": Optional[str]} +) +C = TypeVar("C") + + +class OidcMappingProvider(Generic[C]): + """A mapping provider maps a UserInfo object to user attributes. + + It should provide the API described by this class. + """ + + def __init__(self, config: C): + """ + Args: + config: A custom config object from this module, parsed by ``parse_config()`` + """ + + @staticmethod + def parse_config(config: dict) -> C: + """Parse the dict provided by the homeserver's config + + Args: + config: A dictionary containing configuration options for this provider + + Returns: + A custom config object for this module + """ + raise NotImplementedError() + + def get_remote_user_id(self, userinfo: UserInfo) -> str: + """Get a unique user ID for this user. + + Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object. + + Args: + userinfo: An object representing the user given by the OIDC provider + + Returns: + A unique user ID + """ + raise NotImplementedError() + + async def map_user_attributes( + self, userinfo: UserInfo, token: Token + ) -> UserAttribute: + """Map a ``UserInfo`` objects into user attributes. + + Args: + userinfo: An object representing the user given by the OIDC provider + token: A dict with the tokens returned by the provider + + Returns: + A dict containing the ``localpart`` and (optionally) the ``display_name`` + """ + raise NotImplementedError() + + +# Used to clear out "None" values in templates +def jinja_finalize(thing): + return thing if thing is not None else "" + + +env = Environment(finalize=jinja_finalize) + + +@attr.s +class JinjaOidcMappingConfig: + subject_claim = attr.ib() # type: str + localpart_template = attr.ib() # type: Template + display_name_template = attr.ib() # type: Optional[Template] + + +class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): + """An implementation of a mapping provider based on Jinja templates. + + This is the default mapping provider. + """ + + def __init__(self, config: JinjaOidcMappingConfig): + self._config = config + + @staticmethod + def parse_config(config: dict) -> JinjaOidcMappingConfig: + subject_claim = config.get("subject_claim", "sub") + + if "localpart_template" not in config: + raise ConfigError( + "missing key: oidc_config.user_mapping_provider.config.localpart_template" + ) + + try: + localpart_template = env.from_string(config["localpart_template"]) + except Exception as e: + raise ConfigError( + "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r" + % (e,) + ) + + display_name_template = None # type: Optional[Template] + if "display_name_template" in config: + try: + display_name_template = env.from_string(config["display_name_template"]) + except Exception as e: + raise ConfigError( + "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r" + % (e,) + ) + + return JinjaOidcMappingConfig( + subject_claim=subject_claim, + localpart_template=localpart_template, + display_name_template=display_name_template, + ) + + def get_remote_user_id(self, userinfo: UserInfo) -> str: + return userinfo[self._config.subject_claim] + + async def map_user_attributes( + self, userinfo: UserInfo, token: Token + ) -> UserAttribute: + localpart = self._config.localpart_template.render(user=userinfo).strip() + + display_name = None # type: Optional[str] + if self._config.display_name_template is not None: + display_name = self._config.display_name_template.render( + user=userinfo + ).strip() + + if display_name == "": + display_name = None + + return UserAttribute(localpart=localpart, display_name=display_name) diff --git a/synapse/http/client.py b/synapse/http/client.py index 3797545824..58eb47c69c 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -359,6 +359,7 @@ class SimpleHttpClient(object): actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], b"User-Agent": [self.user_agent], + b"Accept": [b"application/json"], } if headers: actual_headers.update(headers) @@ -399,6 +400,7 @@ class SimpleHttpClient(object): actual_headers = { b"Content-Type": [b"application/json"], b"User-Agent": [self.user_agent], + b"Accept": [b"application/json"], } if headers: actual_headers.update(headers) @@ -434,6 +436,10 @@ class SimpleHttpClient(object): ValueError: if the response was not JSON """ + actual_headers = {b"Accept": [b"application/json"]} + if headers: + actual_headers.update(headers) + body = yield self.get_raw(uri, args, headers=headers) return json.loads(body) @@ -467,6 +473,7 @@ class SimpleHttpClient(object): actual_headers = { b"Content-Type": [b"application/json"], b"User-Agent": [self.user_agent], + b"Accept": [b"application/json"], } if headers: actual_headers.update(headers) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 39c99a2802..8b4312e5a3 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -92,6 +92,7 @@ CONDITIONAL_REQUIREMENTS = { 'eliot<1.8.0;python_version<"3.5.3"', ], "saml2": ["pysaml2>=4.5.0"], + "oidc": ["authlib>=0.14.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], "test": ["mock>=2.0", "parameterized"], diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html new file mode 100644 index 0000000000..43a211386b --- /dev/null +++ b/synapse/res/templates/sso_error.html @@ -0,0 +1,18 @@ + + + + + SSO error + + +

Oops! Something went wrong during authentication.

+

+ Try logging in again from your Matrix client and if the problem persists + please contact the server's administrator. +

+

Error: {{ error }}

+ {% if error_description %} +
{{ error_description }}
+ {% endif %} + + diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 4de2f97d06..de7eca21f8 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -83,6 +83,7 @@ class LoginRestServlet(RestServlet): self.jwt_algorithm = hs.config.jwt_algorithm self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled + self.oidc_enabled = hs.config.oidc_enabled self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() @@ -96,9 +97,7 @@ class LoginRestServlet(RestServlet): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) - if self.saml2_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - flows.append({"type": LoginRestServlet.TOKEN_TYPE}) + if self.cas_enabled: flows.append({"type": LoginRestServlet.SSO_TYPE}) @@ -114,6 +113,11 @@ class LoginRestServlet(RestServlet): # fall back to the fallback API if they don't understand one of the # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) + elif self.saml2_enabled: + flows.append({"type": LoginRestServlet.SSO_TYPE}) + flows.append({"type": LoginRestServlet.TOKEN_TYPE}) + elif self.oidc_enabled: + flows.append({"type": LoginRestServlet.SSO_TYPE}) flows.extend( ({"type": t} for t in self.auth_handler.get_supported_login_types()) @@ -465,6 +469,22 @@ class SAMLRedirectServlet(BaseSSORedirectServlet): return self._saml_handler.handle_redirect_request(client_redirect_url) +class OIDCRedirectServlet(RestServlet): + """Implementation for /login/sso/redirect for the OIDC login flow.""" + + PATTERNS = client_patterns("/login/sso/redirect", v1=True) + + def __init__(self, hs): + self._oidc_handler = hs.get_oidc_handler() + + async def on_GET(self, request): + args = request.args + if b"redirectUrl" not in args: + return 400, "Redirect URL not specified for SSO auth" + client_redirect_url = args[b"redirectUrl"][0] + await self._oidc_handler.handle_redirect_request(request, client_redirect_url) + + def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) if hs.config.cas_enabled: @@ -472,3 +492,5 @@ def register_servlets(hs, http_server): CasTicketServlet(hs).register(http_server) elif hs.config.saml2_enabled: SAMLRedirectServlet(hs).register(http_server) + elif hs.config.oidc_enabled: + OIDCRedirectServlet(hs).register(http_server) diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/oidc/__init__.py new file mode 100644 index 0000000000..d958dd65bb --- /dev/null +++ b/synapse/rest/oidc/__init__.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.web.resource import Resource + +from synapse.rest.oidc.callback_resource import OIDCCallbackResource + +logger = logging.getLogger(__name__) + + +class OIDCResource(Resource): + def __init__(self, hs): + Resource.__init__(self) + self.putChild(b"callback", OIDCCallbackResource(hs)) diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py new file mode 100644 index 0000000000..c03194f001 --- /dev/null +++ b/synapse/rest/oidc/callback_resource.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.http.server import DirectServeResource, wrap_html_request_handler + +logger = logging.getLogger(__name__) + + +class OIDCCallbackResource(DirectServeResource): + isLeaf = 1 + + def __init__(self, hs): + super().__init__() + self._oidc_handler = hs.get_oidc_handler() + + @wrap_html_request_handler + async def _async_render_GET(self, request): + return await self._oidc_handler.handle_oidc_callback(request) diff --git a/synapse/server.py b/synapse/server.py index bf97a16c09..b4aea81e24 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -204,6 +204,7 @@ class HomeServer(object): "account_validity_handler", "cas_handler", "saml_handler", + "oidc_handler", "event_client_serializer", "password_policy_handler", "storage", @@ -562,6 +563,11 @@ class HomeServer(object): return SamlHandler(self) + def build_oidc_handler(self): + from synapse.handlers.oidc_handler import OidcHandler + + return OidcHandler(self) + def build_event_client_serializer(self): return EventClientSerializer(self) diff --git a/synapse/server.pyi b/synapse/server.pyi index 18043a2593..31a9cc0389 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -13,6 +13,7 @@ import synapse.handlers.device import synapse.handlers.e2e_keys import synapse.handlers.message import synapse.handlers.presence +import synapse.handlers.register import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password @@ -128,3 +129,7 @@ class HomeServer(object): pass def get_storage(self) -> synapse.storage.Storage: pass + def get_registration_handler(self) -> synapse.handlers.register.RegistrationHandler: + pass + def get_macaroon_generator(self) -> synapse.handlers.auth.MacaroonGenerator: + pass diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py new file mode 100644 index 0000000000..61963aa90d --- /dev/null +++ b/tests/handlers/test_oidc.py @@ -0,0 +1,565 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from urllib.parse import parse_qs, urlparse + +from mock import Mock, patch + +import attr +import pymacaroons + +from twisted.internet import defer +from twisted.python.failure import Failure +from twisted.web._newclient import ResponseDone + +from synapse.handlers.oidc_handler import ( + MappingException, + OidcError, + OidcHandler, + OidcMappingProvider, +) +from synapse.types import UserID + +from tests.unittest import HomeserverTestCase, override_config + + +@attr.s +class FakeResponse: + code = attr.ib() + body = attr.ib() + phrase = attr.ib() + + def deliverBody(self, protocol): + protocol.dataReceived(self.body) + protocol.connectionLost(Failure(ResponseDone())) + + +# These are a few constants that are used as config parameters in the tests. +ISSUER = "https://issuer/" +CLIENT_ID = "test-client-id" +CLIENT_SECRET = "test-client-secret" +BASE_URL = "https://synapse/" +CALLBACK_URL = BASE_URL + "_synapse/oidc/callback" +SCOPES = ["openid"] + +AUTHORIZATION_ENDPOINT = ISSUER + "authorize" +TOKEN_ENDPOINT = ISSUER + "token" +USERINFO_ENDPOINT = ISSUER + "userinfo" +WELL_KNOWN = ISSUER + ".well-known/openid-configuration" +JWKS_URI = ISSUER + ".well-known/jwks.json" + +# config for common cases +COMMON_CONFIG = { + "discover": False, + "authorization_endpoint": AUTHORIZATION_ENDPOINT, + "token_endpoint": TOKEN_ENDPOINT, + "jwks_uri": JWKS_URI, +} + + +# The cookie name and path don't really matter, just that it has to be coherent +# between the callback & redirect handlers. +COOKIE_NAME = b"oidc_session" +COOKIE_PATH = "/_synapse/oidc" + +MockedMappingProvider = Mock(OidcMappingProvider) + + +def simple_async_mock(return_value=None, raises=None): + # AsyncMock is not available in python3.5, this mimics part of its behaviour + async def cb(*args, **kwargs): + if raises: + raise raises + return return_value + + return Mock(side_effect=cb) + + +async def get_json(url): + # Mock get_json calls to handle jwks & oidc discovery endpoints + if url == WELL_KNOWN: + # Minimal discovery document, as defined in OpenID.Discovery + # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata + return { + "issuer": ISSUER, + "authorization_endpoint": AUTHORIZATION_ENDPOINT, + "token_endpoint": TOKEN_ENDPOINT, + "jwks_uri": JWKS_URI, + "userinfo_endpoint": USERINFO_ENDPOINT, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + } + elif url == JWKS_URI: + return {"keys": []} + + +class OidcHandlerTestCase(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + self.http_client = Mock(spec=["get_json"]) + self.http_client.get_json.side_effect = get_json + self.http_client.user_agent = "Synapse Test" + + config = self.default_config() + config["public_baseurl"] = BASE_URL + oidc_config = config.get("oidc_config", {}) + oidc_config["enabled"] = True + oidc_config["client_id"] = CLIENT_ID + oidc_config["client_secret"] = CLIENT_SECRET + oidc_config["issuer"] = ISSUER + oidc_config["scopes"] = SCOPES + oidc_config["user_mapping_provider"] = { + "module": __name__ + ".MockedMappingProvider" + } + config["oidc_config"] = oidc_config + + hs = self.setup_test_homeserver( + http_client=self.http_client, + proxied_http_client=self.http_client, + config=config, + ) + + self.handler = OidcHandler(hs) + + return hs + + def metadata_edit(self, values): + return patch.dict(self.handler._provider_metadata, values) + + def assertRenderedError(self, error, error_description=None): + args = self.handler._render_error.call_args[0] + self.assertEqual(args[1], error) + if error_description is not None: + self.assertEqual(args[2], error_description) + # Reset the render_error mock + self.handler._render_error.reset_mock() + + def test_config(self): + """Basic config correctly sets up the callback URL and client auth correctly.""" + self.assertEqual(self.handler._callback_url, CALLBACK_URL) + self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID) + self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) + + @override_config({"oidc_config": {"discover": True}}) + @defer.inlineCallbacks + def test_discovery(self): + """The handler should discover the endpoints from OIDC discovery document.""" + # This would throw if some metadata were invalid + metadata = yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + + self.assertEqual(metadata.issuer, ISSUER) + self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) + self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) + self.assertEqual(metadata.jwks_uri, JWKS_URI) + # FIXME: it seems like authlib does not have that defined in its metadata models + # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + + # subsequent calls should be cached + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_not_called() + + @override_config({"oidc_config": COMMON_CONFIG}) + @defer.inlineCallbacks + def test_no_discovery(self): + """When discovery is disabled, it should not try to load from discovery document.""" + yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_not_called() + + @override_config({"oidc_config": COMMON_CONFIG}) + @defer.inlineCallbacks + def test_load_jwks(self): + """JWKS loading is done once (then cached) if used.""" + jwks = yield defer.ensureDeferred(self.handler.load_jwks()) + self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.assertEqual(jwks, {"keys": []}) + + # subsequent calls should be cached… + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_jwks()) + self.http_client.get_json.assert_not_called() + + # …unless forced + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.http_client.get_json.assert_called_once_with(JWKS_URI) + + # Throw if the JWKS uri is missing + with self.metadata_edit({"jwks_uri": None}): + with self.assertRaises(RuntimeError): + yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + + # Return empty key set if JWKS are not used + self.handler._scopes = [] # not asking the openid scope + self.http_client.get_json.reset_mock() + jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.http_client.get_json.assert_not_called() + self.assertEqual(jwks, {"keys": []}) + + @override_config({"oidc_config": COMMON_CONFIG}) + def test_validate_config(self): + """Provider metadatas are extensively validated.""" + h = self.handler + + # Default test config does not throw + h._validate_metadata() + + with self.metadata_edit({"issuer": None}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"issuer": "http://insecure/"}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"issuer": "https://invalid/?because=query"}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"authorization_endpoint": None}): + self.assertRaisesRegex( + ValueError, "authorization_endpoint", h._validate_metadata + ) + + with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}): + self.assertRaisesRegex( + ValueError, "authorization_endpoint", h._validate_metadata + ) + + with self.metadata_edit({"token_endpoint": None}): + self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + + with self.metadata_edit({"token_endpoint": "http://insecure/token"}): + self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + + with self.metadata_edit({"jwks_uri": None}): + self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + + with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}): + self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + + with self.metadata_edit({"response_types_supported": ["id_token"]}): + self.assertRaisesRegex( + ValueError, "response_types_supported", h._validate_metadata + ) + + with self.metadata_edit( + {"token_endpoint_auth_methods_supported": ["client_secret_basic"]} + ): + # should not throw, as client_secret_basic is the default auth method + h._validate_metadata() + + with self.metadata_edit( + {"token_endpoint_auth_methods_supported": ["client_secret_post"]} + ): + self.assertRaisesRegex( + ValueError, + "token_endpoint_auth_methods_supported", + h._validate_metadata, + ) + + # Tests for configs that the userinfo endpoint + self.assertFalse(h._uses_userinfo) + h._scopes = [] # do not request the openid scope + self.assertTrue(h._uses_userinfo) + self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata) + + with self.metadata_edit( + {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None} + ): + # Shouldn't raise with a valid userinfo, even without + h._validate_metadata() + + @override_config({"oidc_config": {"skip_verification": True}}) + def test_skip_verification(self): + """Provider metadata validation can be disabled by config.""" + with self.metadata_edit({"issuer": "http://insecure"}): + # This should not throw + self.handler._validate_metadata() + + @defer.inlineCallbacks + def test_redirect_request(self): + """The redirect request has the right arguments & generates a valid session cookie.""" + req = Mock(spec=["addCookie", "redirect", "finish"]) + yield defer.ensureDeferred( + self.handler.handle_redirect_request(req, b"http://client/redirect") + ) + url = req.redirect.call_args[0][0] + url = urlparse(url) + auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + + self.assertEqual(url.scheme, auth_endpoint.scheme) + self.assertEqual(url.netloc, auth_endpoint.netloc) + self.assertEqual(url.path, auth_endpoint.path) + + params = parse_qs(url.query) + self.assertEqual(params["redirect_uri"], [CALLBACK_URL]) + self.assertEqual(params["response_type"], ["code"]) + self.assertEqual(params["scope"], [" ".join(SCOPES)]) + self.assertEqual(params["client_id"], [CLIENT_ID]) + self.assertEqual(len(params["state"]), 1) + self.assertEqual(len(params["nonce"]), 1) + + # Check what is in the cookie + # note: python3.5 mock does not have the .called_once() method + calls = req.addCookie.call_args_list + self.assertEqual(len(calls), 1) # called once + # For some reason, call.args does not work with python3.5 + args = calls[0][0] + kwargs = calls[0][1] + self.assertEqual(args[0], COOKIE_NAME) + self.assertEqual(kwargs["path"], COOKIE_PATH) + cookie = args[1] + + macaroon = pymacaroons.Macaroon.deserialize(cookie) + state = self.handler._get_value_from_macaroon(macaroon, "state") + nonce = self.handler._get_value_from_macaroon(macaroon, "nonce") + redirect = self.handler._get_value_from_macaroon( + macaroon, "client_redirect_url" + ) + + self.assertEqual(params["state"], [state]) + self.assertEqual(params["nonce"], [nonce]) + self.assertEqual(redirect, "http://client/redirect") + + @defer.inlineCallbacks + def test_callback_error(self): + """Errors from the provider returned in the callback are displayed.""" + self.handler._render_error = Mock() + request = Mock(args={}) + request.args[b"error"] = [b"invalid_client"] + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_client", "") + + request.args[b"error_description"] = [b"some description"] + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_client", "some description") + + @defer.inlineCallbacks + def test_callback(self): + """Code callback works and display errors if something went wrong. + + A lot of scenarios are tested here: + - when the callback works, with userinfo from ID token + - when the user mapping fails + - when ID token verification fails + - when the callback works, with userinfo fetched from the userinfo endpoint + - when the userinfo fetching fails + - when the code exchange fails + """ + token = { + "type": "bearer", + "id_token": "id_token", + "access_token": "access_token", + } + userinfo = { + "sub": "foo", + "preferred_username": "bar", + } + user_id = UserID("foo", "domain.org") + self.handler._render_error = Mock(return_value=None) + self.handler._exchange_code = simple_async_mock(return_value=token) + self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + self.handler._auth_handler.complete_sso_login = simple_async_mock() + request = Mock(spec=["args", "getCookie", "addCookie"]) + + code = "code" + state = "state" + nonce = "nonce" + client_redirect_url = "http://client/redirect" + session = self.handler._generate_oidc_session_token( + state=state, nonce=nonce, client_redirect_url=client_redirect_url, + ) + request.getCookie.return_value = session + + request.args = {} + request.args[b"code"] = [code.encode("utf-8")] + request.args[b"state"] = [state.encode("utf-8")] + + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, + ) + self.handler._exchange_code.assert_called_once_with(code) + self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._fetch_userinfo.assert_not_called() + self.handler._render_error.assert_not_called() + + # Handle mapping errors + self.handler._map_userinfo_to_user = simple_async_mock( + raises=MappingException() + ) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mapping_error") + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + + # Handle ID token errors + self.handler._parse_id_token = simple_async_mock(raises=Exception()) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_token") + + self.handler._auth_handler.complete_sso_login.reset_mock() + self.handler._exchange_code.reset_mock() + self.handler._parse_id_token.reset_mock() + self.handler._map_userinfo_to_user.reset_mock() + self.handler._fetch_userinfo.reset_mock() + + # With userinfo fetching + self.handler._scopes = [] # do not ask the "openid" scope + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, + ) + self.handler._exchange_code.assert_called_once_with(code) + self.handler._parse_id_token.assert_not_called() + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._fetch_userinfo.assert_called_once_with(token) + self.handler._render_error.assert_not_called() + + # Handle userinfo fetching error + self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("fetch_error") + + # Handle code exchange failure + self.handler._exchange_code = simple_async_mock( + raises=OidcError("invalid_request") + ) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request") + + @defer.inlineCallbacks + def test_callback_session(self): + """The callback verifies the session presence and validity""" + self.handler._render_error = Mock(return_value=None) + request = Mock(spec=["args", "getCookie", "addCookie"]) + + # Missing cookie + request.args = {} + request.getCookie.return_value = None + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("missing_session", "No session cookie found") + + # Missing session parameter + request.args = {} + request.getCookie.return_value = "session" + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request", "State parameter is missing") + + # Invalid cookie + request.args = {} + request.args[b"state"] = [b"state"] + request.getCookie.return_value = "session" + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_session") + + # Mismatching session + session = self.handler._generate_oidc_session_token( + state="state", nonce="nonce", client_redirect_url="http://client/redirect", + ) + request.args = {} + request.args[b"state"] = [b"mismatching state"] + request.getCookie.return_value = session + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mismatching_session") + + # Valid session + request.args = {} + request.args[b"state"] = [b"state"] + request.getCookie.return_value = session + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request") + + @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) + @defer.inlineCallbacks + def test_exchange_code(self): + """Code exchange behaves correctly and handles various error scenarios.""" + token = {"type": "bearer"} + token_json = json.dumps(token).encode("utf-8") + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + ) + code = "code" + ret = yield defer.ensureDeferred(self.handler._exchange_code(code)) + kwargs = self.http_client.request.call_args[1] + + self.assertEqual(ret, token) + self.assertEqual(kwargs["method"], "POST") + self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + + args = parse_qs(kwargs["data"].decode("utf-8")) + self.assertEqual(args["grant_type"], ["authorization_code"]) + self.assertEqual(args["code"], [code]) + self.assertEqual(args["client_id"], [CLIENT_ID]) + self.assertEqual(args["client_secret"], [CLIENT_SECRET]) + self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + + # Test error handling + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=400, + phrase=b"Bad Request", + body=b'{"error": "foo", "error_description": "bar"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "foo") + self.assertEqual(exc.exception.error_description, "bar") + + # Internal server error with no JSON body + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=500, phrase=b"Internal Server Error", body=b"Not JSON", + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "server_error") + + # Internal server error with JSON body + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=500, + phrase=b"Internal Server Error", + body=b'{"error": "internal_server_error"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "internal_server_error") + + # 4xx error without "error" field + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "server_error") + + # 2xx error with "error" field + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=200, phrase=b"OK", body=b'{"error": "some_error"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "some_error") diff --git a/tox.ini b/tox.ini index c699f3e46a..ad1902d47d 100644 --- a/tox.ini +++ b/tox.ini @@ -185,6 +185,7 @@ commands = mypy \ synapse/handlers/auth.py \ synapse/handlers/cas_handler.py \ synapse/handlers/directory.py \ + synapse/handlers/oidc_handler.py \ synapse/handlers/presence.py \ synapse/handlers/saml_handler.py \ synapse/handlers/sync.py \ -- cgit 1.4.1 From 4734a7bbe4d08d68c5f04dd76cd5bcfb4cd9b6be Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 14 May 2020 14:01:39 +0100 Subject: Move EventStream handling into default ReplicationDataHandler (#7493) This is so that the logic can happen on both master and workers when we move event persistence out. --- changelog.d/7493.misc | 1 + synapse/app/generic_worker.py | 33 ++------------------------------- synapse/replication/tcp/client.py | 37 +++++++++++++++++++++++++++++++++---- synapse/server.py | 2 +- synapse/server.pyi | 3 +++ 5 files changed, 40 insertions(+), 36 deletions(-) create mode 100644 changelog.d/7493.misc (limited to 'synapse/server.pyi') diff --git a/changelog.d/7493.misc b/changelog.d/7493.misc new file mode 100644 index 0000000000..575c55a99b --- /dev/null +++ b/changelog.d/7493.misc @@ -0,0 +1 @@ +Move EventStream handling into default ReplicationDataHandler. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index bccb1140b2..2e3add7ac5 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -26,7 +26,6 @@ from twisted.web.resource import NoResource import synapse import synapse.events -from synapse.api.constants import EventTypes from synapse.api.errors import HttpResponseException, SynapseError from synapse.api.urls import ( CLIENT_API_PREFIX, @@ -81,11 +80,6 @@ from synapse.replication.tcp.streams import ( ToDeviceStream, TypingStream, ) -from synapse.replication.tcp.streams.events import ( - EventsStream, - EventsStreamEventRow, - EventsStreamRow, -) from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client.v1 import events from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet @@ -633,7 +627,7 @@ class GenericWorkerServer(HomeServer): class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): - super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore()) + super(GenericWorkerReplicationHandler, self).__init__(hs) self.store = hs.get_datastore() self.typing_handler = hs.get_typing_handler() @@ -659,30 +653,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): stream_name, token, rows ) - if stream_name == EventsStream.NAME: - # We shouldn't get multiple rows per token for events stream, so - # we don't need to optimise this for multiple rows. - for row in rows: - if row.type != EventsStreamEventRow.TypeId: - continue - assert isinstance(row, EventsStreamRow) - - event = await self.store.get_event( - row.data.event_id, allow_rejected=True - ) - if event.rejected_reason: - continue - - extra_users = () - if event.type == EventTypes.Member: - extra_users = (event.state_key,) - max_token = self.store.get_room_max_stream_ordering() - self.notifier.on_new_room_event( - event, token, max_token, extra_users - ) - - await self.pusher_pool.on_new_notifications(token, token) - elif stream_name == PushRulesStream.NAME: + if stream_name == PushRulesStream.NAME: self.notifier.on_new_event( "push_rules_key", token, users=[row.user_id for row in rows] ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 20cb8a654f..28826302f5 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,12 +16,17 @@ """ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple from twisted.internet.protocol import ReconnectingClientFactory -from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.api.constants import EventTypes from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.streams.events import ( + EventsStream, + EventsStreamEventRow, + EventsStreamRow, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -83,8 +88,10 @@ class ReplicationDataHandler: to handle updates in additional ways. """ - def __init__(self, store: BaseSlavedStore): - self.store = store + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastore() + self.pusher_pool = hs.get_pusherpool() + self.notifier = hs.get_notifier() async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -102,6 +109,28 @@ class ReplicationDataHandler: """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + if stream_name == EventsStream.NAME: + # We shouldn't get multiple rows per token for events stream, so + # we don't need to optimise this for multiple rows. + for row in rows: + if row.type != EventsStreamEventRow.TypeId: + continue + assert isinstance(row, EventsStreamRow) + + event = await self.store.get_event( + row.data.event_id, allow_rejected=True + ) + if event.rejected_reason: + continue + + extra_users = () # type: Tuple[str, ...] + if event.type == EventTypes.Member: + extra_users = (event.state_key,) + max_token = self.store.get_room_max_stream_ordering() + self.notifier.on_new_room_event(event, token, max_token, extra_users) + + await self.pusher_pool.on_new_notifications(token, token) + async def on_position(self, stream_name: str, instance_name: str, token: int): self.store.process_replication_rows(stream_name, instance_name, token, []) diff --git a/synapse/server.py b/synapse/server.py index b4aea81e24..c530f1aa1a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -581,7 +581,7 @@ class HomeServer(object): return ReplicationStreamer(self) def build_replication_data_handler(self): - return ReplicationDataHandler(self.get_datastore()) + return ReplicationDataHandler(self) def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 31a9cc0389..9e7fad7e6e 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -19,6 +19,7 @@ import synapse.handlers.room_member import synapse.handlers.set_password import synapse.http.client import synapse.notifier +import synapse.push.pusherpool import synapse.replication.tcp.client import synapse.replication.tcp.handler import synapse.rest.media.v1.media_repository @@ -133,3 +134,5 @@ class HomeServer(object): pass def get_macaroon_generator(self) -> synapse.handlers.auth.MacaroonGenerator: pass + def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool: + pass -- cgit 1.4.1 From 1531b214fc57714c14046a8f66c7b5fe5ec5dcdd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 May 2020 14:21:54 +0100 Subject: Add ability to wait for replication streams (#7542) The idea here is that if an instance persists an event via the replication HTTP API it can return before we receive that event over replication, which can lead to races where code assumes that persisting an event immediately updates various caches (e.g. current state of the room). Most of Synapse doesn't hit such races, so we don't do the waiting automagically, instead we do so where necessary to avoid unnecessary delays. We may decide to change our minds here if it turns out there are a lot of subtle races going on. People probably want to look at this commit by commit. --- changelog.d/7542.misc | 1 + synapse/handlers/federation.py | 33 ++++++--- synapse/handlers/message.py | 36 ++++++--- synapse/handlers/room.py | 65 +++++++++++----- synapse/handlers/room_member.py | 65 ++++++++++------ synapse/handlers/room_member_worker.py | 11 +-- synapse/replication/http/federation.py | 13 +++- synapse/replication/http/membership.py | 14 ++-- synapse/replication/http/send_event.py | 4 +- synapse/replication/http/streams.py | 5 +- synapse/replication/tcp/client.py | 90 ++++++++++++++++++++++- synapse/rest/admin/rooms.py | 10 ++- synapse/rest/client/v1/room.py | 20 +++-- synapse/rest/client/v2_alpha/relations.py | 2 +- synapse/server.py | 5 ++ synapse/server.pyi | 5 ++ synapse/server_notices/server_notices_manager.py | 6 +- synapse/storage/data_stores/main/events_worker.py | 6 +- synapse/storage/data_stores/main/roommember.py | 2 + tests/federation/test_complexity.py | 8 +- tests/handlers/test_typing.py | 5 +- tests/storage/test_cleanup_extrems.py | 4 +- tests/storage/test_event_metrics.py | 2 +- tests/test_federation.py | 4 +- 24 files changed, 304 insertions(+), 112 deletions(-) create mode 100644 changelog.d/7542.misc (limited to 'synapse/server.pyi') diff --git a/changelog.d/7542.misc b/changelog.d/7542.misc new file mode 100644 index 0000000000..7dd9b4823b --- /dev/null +++ b/changelog.d/7542.misc @@ -0,0 +1 @@ +Add ability to wait for replication streams. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bb03cc9add..e354c803db 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -126,6 +126,7 @@ class FederationHandler(BaseHandler): self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() + self._replication = hs.get_replication_data_handler() self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( hs @@ -1221,7 +1222,7 @@ class FederationHandler(BaseHandler): async def do_invite_join( self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict - ) -> None: + ) -> Tuple[str, int]: """ Attempts to join the `joinee` to the room `room_id` via the servers contained in `target_hosts`. @@ -1304,15 +1305,23 @@ class FederationHandler(BaseHandler): room_id=room_id, room_version=room_version_obj, ) - await self._persist_auth_tree( + max_stream_id = await self._persist_auth_tree( origin, auth_chain, state, event, room_version_obj ) + # We wait here until this instance has seen the events come down + # replication (if we're using replication) as the below uses caches. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position( + "master", "events", max_stream_id + ) + # Check whether this room is the result of an upgrade of a room we already know # about. If so, migrate over user information predecessor = await self.store.get_room_predecessor(room_id) if not predecessor or not isinstance(predecessor.get("room_id"), str): - return + return event.event_id, max_stream_id old_room_id = predecessor["room_id"] logger.debug( "Found predecessor for %s during remote join: %s", room_id, old_room_id @@ -1325,6 +1334,7 @@ class FederationHandler(BaseHandler): ) logger.debug("Finished joining %s to %s", joinee, room_id) + return event.event_id, max_stream_id finally: room_queue = self.room_queues[room_id] del self.room_queues[room_id] @@ -1554,7 +1564,7 @@ class FederationHandler(BaseHandler): async def do_remotely_reject_invite( self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict - ) -> EventBase: + ) -> Tuple[EventBase, int]: origin, event, room_version = await self._make_and_verify_event( target_hosts, room_id, user_id, "leave", content=content ) @@ -1574,9 +1584,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(target_hosts, event) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify([(event, context)]) - return event + return event, stream_id async def _make_and_verify_event( self, @@ -1888,7 +1898,7 @@ class FederationHandler(BaseHandler): state: List[EventBase], event: EventBase, room_version: RoomVersion, - ) -> None: + ) -> int: """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event separately. Notifies about the persisted events @@ -1982,7 +1992,7 @@ class FederationHandler(BaseHandler): event, old_state=state ) - await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify([(event, new_event_context)]) async def _prep_event( self, @@ -2835,7 +2845,7 @@ class FederationHandler(BaseHandler): self, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> None: + ) -> int: """Persists events and tells the notifier/pushers about them, if necessary. @@ -2845,11 +2855,12 @@ class FederationHandler(BaseHandler): backfilling or not """ if self.config.worker_app: - await self._send_events_to_master( + result = await self._send_events_to_master( store=self.store, event_and_contexts=event_and_contexts, backfilled=backfilled, ) + return result["max_stream_id"] else: max_stream_id = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled @@ -2864,6 +2875,8 @@ class FederationHandler(BaseHandler): for event, _ in event_and_contexts: await self._notify_persisted_event(event, max_stream_id) + return max_stream_id + async def _notify_persisted_event( self, event: EventBase, max_stream_id: int ) -> None: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8f362896a2..f445e2aa2a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import Optional, Tuple from six import iteritems, itervalues, string_types @@ -42,6 +42,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder +from synapse.events import EventBase from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -630,7 +631,9 @@ class EventCreationHandler(object): msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) - async def send_nonmember_event(self, requester, event, context, ratelimit=True): + async def send_nonmember_event( + self, requester, event, context, ratelimit=True + ) -> int: """ Persists and notifies local clients and federation of an event. @@ -639,6 +642,9 @@ class EventCreationHandler(object): context (Context) the context of the event. ratelimit (bool): Whether to rate limit this send. is_guest (bool): Whether the sender is a guest. + + Return: + The stream_id of the persisted event. """ if event.type == EventTypes.Member: raise SynapseError( @@ -659,7 +665,7 @@ class EventCreationHandler(object): ) return prev_state - await self.handle_new_client_event( + return await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit ) @@ -688,7 +694,7 @@ class EventCreationHandler(object): async def create_and_send_nonmember_event( self, requester, event_dict, ratelimit=True, txn_id=None - ): + ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. @@ -711,10 +717,10 @@ class EventCreationHandler(object): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) - await self.send_nonmember_event( + stream_id = await self.send_nonmember_event( requester, event, context, ratelimit=ratelimit ) - return event + return event, stream_id @measure_func("create_new_client_event") @defer.inlineCallbacks @@ -774,7 +780,7 @@ class EventCreationHandler(object): @measure_func("handle_new_client_event") async def handle_new_client_event( self, requester, event, context, ratelimit=True, extra_users=[] - ): + ) -> int: """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -787,6 +793,9 @@ class EventCreationHandler(object): context (EventContext) ratelimit (bool) extra_users (list(UserID)): Any extra users to notify about event + + Return: + The stream_id of the persisted event. """ if event.is_state() and (event.type, event.state_key) == ( @@ -827,7 +836,7 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. if self.config.worker_app: - await self.send_event_to_master( + result = await self.send_event_to_master( event_id=event.event_id, store=self.store, requester=requester, @@ -836,14 +845,17 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) + stream_id = result["stream_id"] + event.internal_metadata.stream_ordering = stream_id success = True - return + return stream_id - await self.persist_and_notify_client_event( + stream_id = await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) success = True + return stream_id finally: if not success: # Ensure that we actually remove the entries in the push actions @@ -886,7 +898,7 @@ class EventCreationHandler(object): async def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] - ): + ) -> int: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -1076,6 +1088,8 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) + return event_stream_id + async def _bump_active_time(self, user): try: presence = self.hs.get_presence_handler() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 13850ba672..2698a129ca 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -22,6 +22,7 @@ import logging import math import string from collections import OrderedDict +from typing import Tuple from six import iteritems, string_types @@ -518,7 +519,7 @@ class RoomCreationHandler(BaseHandler): async def create_room( self, requester, config, ratelimit=True, creator_join_profile=None - ): + ) -> Tuple[dict, int]: """ Creates a new room. Args: @@ -535,9 +536,9 @@ class RoomCreationHandler(BaseHandler): `avatar_url` and/or `displayname`. Returns: - Deferred[dict]: - a dict containing the keys `room_id` and, if an alias was - requested, `room_alias`. + First, a dict containing the keys `room_id` and, if an alias + was, requested, `room_alias`. Secondly, the stream_id of the + last persisted event. Raises: SynapseError if the room ID couldn't be stored, or something went horribly wrong. @@ -669,7 +670,7 @@ class RoomCreationHandler(BaseHandler): # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier - await self._send_events_for_new_room( + last_stream_id = await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -683,7 +684,10 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -697,7 +701,10 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -715,7 +722,7 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct - await self.room_member_handler.update_membership( + _, last_stream_id = await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), room_id, @@ -729,7 +736,7 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] - await self.hs.get_room_member_handler().do_3pid_invite( + last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -745,7 +752,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - return result + return result, last_stream_id async def _send_events_for_new_room( self, @@ -758,7 +765,13 @@ class RoomCreationHandler(BaseHandler): room_alias=None, power_level_content_override=None, # Doesn't apply when initial state has power level state event content creator_join_profile=None, - ): + ) -> int: + """Sends the initial events into a new room. + + Returns: + The stream_id of the last event persisted. + """ + def create(etype, content, **kwargs): e = {"type": etype, "content": content} @@ -767,12 +780,16 @@ class RoomCreationHandler(BaseHandler): return e - async def send(etype, content, **kwargs): + async def send(etype, content, **kwargs) -> int: event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False ) + return last_stream_id config = RoomCreationHandler.PRESETS_DICT[preset_config] @@ -797,7 +814,9 @@ class RoomCreationHandler(BaseHandler): # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - await send(etype=EventTypes.PowerLevels, content=pl_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=pl_content + ) else: power_level_content = { "users": {creator_id: 100}, @@ -830,33 +849,39 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - await send(etype=EventTypes.PowerLevels, content=power_level_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=power_level_content + ) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) if (EventTypes.JoinRules, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - await send(etype=etype, state_key=state_key, content=content) + last_sent_stream_id = await send( + etype=etype, state_key=state_key, content=content + ) + + return last_sent_stream_id async def _generate_room_id( self, creator_id: str, is_public: str, room_version: RoomVersion, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e51e1c32fe..691b6705b2 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,7 +17,7 @@ import abc import logging -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple from six.moves import http_client @@ -84,7 +84,7 @@ class RoomMemberHandler(object): room_id: str, user: UserID, content: dict, - ) -> Optional[dict]: + ) -> Tuple[str, int]: """Try and join a room that this server is not in Args: @@ -104,7 +104,7 @@ class RoomMemberHandler(object): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. @@ -154,7 +154,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> EventBase: + ) -> Tuple[str, int]: user_id = target.to_string() if content is None: @@ -187,9 +187,10 @@ class RoomMemberHandler(object): ) if duplicate is not None: # Discard the new event since this membership change is a no-op. - return duplicate + _, stream_id = await self.store.get_event_ordering(duplicate.event_id) + return duplicate.event_id, stream_id - await self.event_creation_handler.handle_new_client_event( + stream_id = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit ) @@ -213,7 +214,7 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target, room_id) - return event + return event.event_id, stream_id async def copy_room_tags_and_direct_to_room( self, old_room_id, new_room_id, user_id @@ -263,7 +264,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Union[EventBase, Optional[dict]]: + ) -> Tuple[Optional[str], int]: key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -294,7 +295,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Union[EventBase, Optional[dict]]: + ) -> Tuple[Optional[str], int]: content_specified = bool(content) if content is None: content = {} @@ -398,7 +399,13 @@ class RoomMemberHandler(object): same_membership = old_membership == effective_membership_state same_sender = requester.user.to_string() == old_state.sender if same_sender and same_membership and same_content: - return old_state + _, stream_id = await self.store.get_event_ordering( + old_state.event_id + ) + return ( + old_state.event_id, + stream_id, + ) if old_membership in ["ban", "leave"] and action == "kick": raise AuthError(403, "The target user is not in the room") @@ -705,7 +712,7 @@ class RoomMemberHandler(object): requester: Requester, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> None: + ) -> int: if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -737,11 +744,11 @@ class RoomMemberHandler(object): ) if invitee: - await self.update_membership( + _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: - await self._make_and_store_3pid_invite( + stream_id = await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -752,6 +759,8 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) + return stream_id + async def _make_and_store_3pid_invite( self, requester: Requester, @@ -762,7 +771,7 @@ class RoomMemberHandler(object): user: UserID, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> None: + ) -> int: room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" @@ -817,7 +826,10 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - await self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -835,6 +847,7 @@ class RoomMemberHandler(object): ratelimit=False, txn_id=txn_id, ) + return stream_id async def _is_host_in_room( self, current_state_ids: Dict[Tuple[str, str], str] @@ -916,7 +929,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> None: + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -945,7 +958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - await self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user.to_string(), content ) await self._user_joined_room(user, room_id) @@ -955,14 +968,14 @@ class RoomMemberMasterHandler(RoomMemberHandler): if self.hs.config.limit_remote_rooms.enabled: if too_complex is False: # We checked, and we're under the limit. - return + return event_id, stream_id # Check again, but with the local state events too_complex = await self._is_local_room_too_complex(room_id) if too_complex is False: # We're under the limit. - return + return event_id, stream_id # The room is too large. Leave. requester = types.create_requester(user, None, False, None) @@ -975,6 +988,8 @@ class RoomMemberMasterHandler(RoomMemberHandler): errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) + return event_id, stream_id + async def _remote_reject_invite( self, requester: Requester, @@ -982,15 +997,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler try: - ret = await fed_handler.do_remotely_reject_invite( + event, stream_id = await fed_handler.do_remotely_reject_invite( remote_room_hosts, room_id, target.to_string(), content=content, ) - return ret + return event.event_id, stream_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -1000,8 +1015,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - await self.store.locally_reject_invite(target.to_string(), room_id) - return {} + stream_id = await self.store.locally_reject_invite( + target.to_string(), room_id + ) + return None, stream_id async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 5c776cc0be..02e0c4103d 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import List, Optional +from typing import List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler @@ -43,7 +43,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> Optional[dict]: + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ if len(remote_room_hosts) == 0: @@ -59,7 +59,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): await self._user_joined_room(user, room_id) - return ret + return ret["event_id"], ret["stream_id"] async def _remote_reject_invite( self, @@ -68,16 +68,17 @@ class RoomMemberWorkerHandler(RoomMemberHandler): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ - return await self._remote_reject_client( + ret = await self._remote_reject_client( requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, user_id=target.to_string(), content=content, ) + return ret["event_id"], ret["stream_id"] async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 7e23b565b9..c287c4e269 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """Handles events newly received from federation, including persisting and - notifying. + notifying. Returns the maximum stream ID of the persisted events. The API looks like: @@ -46,6 +46,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): "context": { .. serialized event context .. }, }], "backfilled": false + } + + 200 OK + + { + "max_stream_id": 32443, + } """ NAME = "fed_send_events" @@ -115,11 +122,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) - await self.federation_handler.persist_events_and_notify( + max_stream_id = await self.federation_handler.persist_events_and_notify( event_and_contexts, backfilled ) - return 200, {} + return 200, {"max_stream_id": max_stream_id} class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 3577611fd7..050fd34562 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -76,11 +76,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): logger.info("remote_join: %s into room: %s", user_id, room_id) - await self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user_id, event_content ) - return 200, {} + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): @@ -136,10 +136,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: - event = await self.federation_handler.do_remotely_reject_invite( + event, stream_id = await self.federation_handler.do_remotely_reject_invite( remote_room_hosts, room_id, user_id, event_content, ) - ret = event.get_pdu_json() + event_id = event.event_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -149,10 +149,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warning("Failed to reject invite: %s", e) - await self.store.locally_reject_invite(user_id, room_id) - ret = {} + stream_id = await self.store.locally_reject_invite(user_id, room_id) + event_id = None - return 200, ret + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index b74b088ff4..c981723c1a 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -119,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - await self.event_creation_handler.persist_and_notify_client_event( + stream_id = await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return 200, {} + return 200, {"stream_id": stream_id} def register_servlets(hs, http_server): diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index b705a8e16c..bde97eef32 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -51,10 +51,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): super().__init__(hs) self._instance_name = hs.get_instance_name() - - # We pull the streams from the replication handler (if we try and make - # them ourselves we end up in an import loop). - self.streams = hs.get_tcp_replication().get_streams() + self.streams = hs.get_replication_streams() @staticmethod def _serialize_payload(stream_name, from_token, upto_token): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 28826302f5..508ad1b720 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,19 +14,23 @@ # limitations under the License. """A replication client for use by synapse workers. """ - +import heapq import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple +from twisted.internet.defer import Deferred from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, EventsStreamRow, ) +from synapse.util.async_helpers import timeout_deferred +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -35,6 +39,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# How long we allow callers to wait for replication updates before timing out. +_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30 + + class DirectTcpReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. @@ -92,6 +100,16 @@ class ReplicationDataHandler: self.store = hs.get_datastore() self.pusher_pool = hs.get_pusherpool() self.notifier = hs.get_notifier() + self._reactor = hs.get_reactor() + self._clock = hs.get_clock() + self._streams = hs.get_replication_streams() + self._instance_name = hs.get_instance_name() + + # Map from stream to list of deferreds waiting for the stream to + # arrive at a particular position. The lists are sorted by stream position. + self._streams_to_waiters = ( + {} + ) # type: Dict[str, List[Tuple[int, Deferred[None]]]] async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -131,8 +149,76 @@ class ReplicationDataHandler: await self.pusher_pool.on_new_notifications(token, token) + # Notify any waiting deferreds. The list is ordered by position so we + # just iterate through the list until we reach a position that is + # greater than the received row position. + waiting_list = self._streams_to_waiters.get(stream_name, []) + + # Index of first item with a position after the current token, i.e we + # have called all deferreds before this index. If not overwritten by + # loop below means either a) no items in list so no-op or b) all items + # in list were called and so the list should be cleared. Setting it to + # `len(list)` works for both cases. + index_of_first_deferred_not_called = len(waiting_list) + + for idx, (position, deferred) in enumerate(waiting_list): + if position <= token: + try: + with PreserveLoggingContext(): + deferred.callback(None) + except Exception: + # The deferred has been cancelled or timed out. + pass + else: + # The list is sorted by position so we don't need to continue + # checking any futher entries in the list. + index_of_first_deferred_not_called = idx + break + + # Drop all entries in the waiting list that were called in the above + # loop. (This maintains the order so no need to resort) + waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + async def on_position(self, stream_name: str, instance_name: str, token: int): self.store.process_replication_rows(stream_name, instance_name, token, []) def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" + + async def wait_for_stream_position( + self, instance_name: str, stream_name: str, position: int + ): + """Wait until this instance has received updates up to and including + the given stream position. + """ + + if instance_name == self._instance_name: + # We don't get told about updates written by this process, and + # anyway in that case we don't need to wait. + return + + current_position = self._streams[stream_name].current_token(self._instance_name) + if position <= current_position: + # We're already past the position + return + + # Create a new deferred that times out after N seconds, as we don't want + # to wedge here forever. + deferred = Deferred() + deferred = timeout_deferred( + deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor + ) + + waiting_list = self._streams_to_waiters.setdefault(stream_name, []) + + # We insert into the list using heapq as it is more efficient than + # pushing then resorting each time. + heapq.heappush(waiting_list, (position, deferred)) + + # We measure here to get in flight counts and average waiting time. + with Measure(self._clock, "repl.wait_for_stream_position"): + logger.info("Waiting for repl stream %r to reach %s", stream_name, position) + await make_deferred_yieldable(deferred) + logger.info( + "Finished waiting for repl stream %r to reach %s", stream_name, position + ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 7d40001988..0a13e1ed34 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -59,6 +59,7 @@ class ShutdownRoomRestServlet(RestServlet): self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() + self._replication = hs.get_replication_data_handler() async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request) @@ -73,7 +74,7 @@ class ShutdownRoomRestServlet(RestServlet): message = content.get("message", self.DEFAULT_MESSAGE) room_name = content.get("room_name", "Content Violation Notification") - info = await self._room_creation_handler.create_room( + info, stream_id = await self._room_creation_handler.create_room( room_creator_requester, config={ "preset": "public_chat", @@ -94,6 +95,13 @@ class ShutdownRoomRestServlet(RestServlet): # desirable in case the first attempt at blocking the room failed below. await self.store.block_room(room_id, requester_user_id) + # We now wait for the create room to come back in via replication so + # that we can assume that all the joins/invites have propogated before + # we try and auto join below. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position("master", "events", stream_id) + users = await self.state.get_current_users_in_room(room_id) kicked_users = [] failed_to_kick_users = [] diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6b5830cc3f..105e0cf4d2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -93,7 +93,7 @@ class RoomCreateRestServlet(TransactionRestServlet): async def on_POST(self, request): requester = await self.auth.get_user_by_req(request) - info = await self._room_creation_handler.create_room( + info, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -202,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = await self.room_member_handler.update_membership( + event_id, _ = await self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -210,14 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet): content=content, ) else: - event = await self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) + event_id = event.event_id ret = {} # type: dict - if event: - set_tag("event_id", event.event_id) - ret = {"event_id": event.event_id} + if event_id: + set_tag("event_id", event_id) + ret = {"event_id": event_id} return 200, ret @@ -247,7 +251,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -781,7 +785,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 63f07b63da..89002ffbff 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -111,7 +111,7 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict=event_dict, txn_id=txn_id ) diff --git a/synapse/server.py b/synapse/server.py index c530f1aa1a..ca2deb49bb 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -90,6 +90,7 @@ from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer +from synapse.replication.tcp.streams import STREAMS_MAP from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -210,6 +211,7 @@ class HomeServer(object): "storage", "replication_streamer", "replication_data_handler", + "replication_streams", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -583,6 +585,9 @@ class HomeServer(object): def build_replication_data_handler(self): return ReplicationDataHandler(self) + def build_replication_streams(self): + return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()} + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9e7fad7e6e..fe8024d2d4 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,5 @@ +from typing import Dict + import twisted.internet import synapse.api.auth @@ -28,6 +30,7 @@ import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage from synapse.events.builder import EventBuilderFactory +from synapse.replication.tcp.streams import Stream class HomeServer(object): @property @@ -136,3 +139,5 @@ class HomeServer(object): pass def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool: pass + def get_replication_streams(self) -> Dict[str, Stream]: + pass diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 999c621b92..bf2454c01c 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -83,10 +83,10 @@ class ServerNoticesManager(object): if state_key is not None: event_dict["state_key"] = state_key - res = await self._event_creation_handler.create_and_send_nonmember_event( + event, _ = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, ratelimit=False ) - return res + return event @cached() async def get_or_create_notice_room_for_user(self, user_id): @@ -143,7 +143,7 @@ class ServerNoticesManager(object): } requester = create_requester(self.server_notices_mxid) - info = await self._room_creation_handler.create_room( + info, _ = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 9130b74eb5..b880a71782 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -1289,12 +1289,12 @@ class EventsWorkerStore(SQLBaseStore): async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ - to_1, so_1 = await self._get_event_ordering(event_id1) - to_2, so_2 = await self._get_event_ordering(event_id2) + to_1, so_1 = await self.get_event_ordering(event_id1) + to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cachedInlineCallbacks(max_entries=5000) - def _get_event_ordering(self, event_id): + def get_event_ordering(self, event_id): res = yield self.db.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 1e9c850152..7c5ca81ae0 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -1069,6 +1069,8 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) + return stream_ordering + def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 94980733c4..0c9987be54 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -79,7 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) d = handler._remote_join( None, @@ -115,7 +117,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) # Artificially raise the complexity self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 51e2b37218..2fa8d4739b 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -86,7 +86,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): reactor.pump((1000,)) hs = self.setup_test_homeserver( - notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring + notifier=Mock(), + http_client=mock_federation_client, + keyring=mock_keyring, + replication_streams={}, ) hs.datastores = datastores diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 0e04b2cf92..43425c969a 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -39,7 +39,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] def run_background_update(self): @@ -261,7 +261,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() homeserver.config.user_consent_version = self.CONSENT_VERSION diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a7b7fd36d3..a7b85004e5 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): events = [(3, 2), (6, 2), (4, 6)] for event_count, extrems in events: - info = self.get_success(room_creator.create_room(requester, {})) + info, _ = self.get_success(room_creator.create_room(requester, {})) room_id = info["room_id"] last_event = None diff --git a/tests/test_federation.py b/tests/test_federation.py index 13ff14863e..c5099dd039 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -28,13 +28,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room = ensureDeferred( + room_deferred = ensureDeferred( room_creator.create_room( our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False ) ) self.reactor.advance(0.1) - self.room_id = self.successResultOf(room)["room_id"] + self.room_id = self.successResultOf(room_deferred)[0]["room_id"] self.store = self.homeserver.get_datastore() -- cgit 1.4.1