summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/constants.py3
-rw-r--r--synapse/api/urls.py4
-rwxr-xr-xsynapse/app/homeserver.py56
-rw-r--r--synapse/config/_base.py4
-rw-r--r--synapse/config/key.py3
-rw-r--r--synapse/config/registration.py9
-rw-r--r--synapse/config/tls.py149
-rw-r--r--synapse/crypto/context_factory.py21
-rw-r--r--synapse/crypto/keyclient.py149
-rw-r--r--synapse/crypto/keyring.py30
-rw-r--r--synapse/federation/federation_server.py4
-rw-r--r--synapse/federation/transport/client.py132
-rw-r--r--synapse/federation/transport/server.py46
-rw-r--r--synapse/handlers/acme.py147
-rw-r--r--synapse/handlers/identity.py9
-rw-r--r--synapse/handlers/room.py1
-rw-r--r--synapse/handlers/room_list.py15
-rw-r--r--synapse/handlers/user_directory.py1
-rw-r--r--synapse/http/endpoint.py280
-rw-r--r--synapse/http/federation/__init__.py14
-rw-r--r--synapse/http/federation/matrix_federation_agent.py124
-rw-r--r--synapse/http/federation/srv_resolver.py169
-rw-r--r--synapse/http/matrixfederationclient.py227
-rw-r--r--synapse/python_dependencies.py19
-rw-r--r--synapse/rest/client/v2_alpha/register.py16
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/storage/_base.py158
-rw-r--r--synapse/storage/client_ips.py156
-rw-r--r--synapse/storage/engines/__init__.py2
-rw-r--r--synapse/storage/engines/postgres.py14
-rw-r--r--synapse/storage/engines/sqlite.py (renamed from synapse/storage/engines/sqlite3.py)9
-rw-r--r--synapse/storage/events.py13
-rw-r--r--synapse/storage/pusher.py9
-rw-r--r--synapse/storage/schema/delta/53/user_ips_index.sql26
-rw-r--r--synapse/storage/user_directory.py55
-rw-r--r--synapse/util/async_helpers.py4
36 files changed, 1300 insertions, 783 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py

index 87bc1cb53d..46c4b4b9dc 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py
@@ -68,6 +68,7 @@ class EventTypes(object): Aliases = "m.room.aliases" Redaction = "m.room.redaction" ThirdPartyInvite = "m.room.third_party_invite" + Encryption = "m.room.encryption" RoomHistoryVisibility = "m.room.history_visibility" CanonicalAlias = "m.room.canonical_alias" @@ -128,4 +129,4 @@ class UserTypes(object): 'admin' and 'guest' users should also be UserTypes. Normal users are type None """ SUPPORT = "support" - ALL_USER_TYPES = (SUPPORT) + ALL_USER_TYPES = (SUPPORT,) diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index f78695b657..8102176653 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py
@@ -24,7 +24,9 @@ from synapse.config import ConfigError CLIENT_PREFIX = "/_matrix/client/api/v1" CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha" -FEDERATION_PREFIX = "/_matrix/federation/v1" +FEDERATION_PREFIX = "/_matrix/federation" +FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" +FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" STATIC_PREFIX = "/_matrix/static" WEB_CLIENT_PREFIX = "/_matrix/client" CONTENT_REPO_PREFIX = "/_matrix/content" diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index f3ac3d19f0..ffc49d77cc 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py
@@ -13,10 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gc import logging import os import sys +import traceback from six import iteritems @@ -324,17 +326,12 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - tls_server_context_factory = context_factory.ServerContextFactory(config) - tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) - database_engine = create_engine(config.database_config) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection hs = SynapseHomeServer( config.server_name, db_config=config.database_config, - tls_server_context_factory=tls_server_context_factory, - tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, @@ -361,12 +358,53 @@ def setup(config_options): logger.info("Database prepared in %s.", config.database_config['name']) hs.setup() - hs.start_listening() + @defer.inlineCallbacks def start(): - hs.get_pusherpool().start() - hs.get_datastore().start_profiling() - hs.get_datastore().start_doing_background_updates() + try: + # Check if the certificate is still valid. + cert_days_remaining = hs.config.is_disk_cert_valid() + + if hs.config.acme_enabled: + # If ACME is enabled, we might need to provision a certificate + # before starting. + acme = hs.get_acme_handler() + + # Start up the webservices which we will respond to ACME + # challenges with. + yield acme.start_listening() + + # We want to reprovision if cert_days_remaining is None (meaning no + # certificate exists), or the days remaining number it returns + # is less than our re-registration threshold. + if (cert_days_remaining is None) or ( + not cert_days_remaining > hs.config.acme_reprovision_threshold + ): + yield acme.provision_certificate() + + # Read the certificate from disk and build the context factories for + # TLS. + hs.config.read_certificate_from_disk() + hs.tls_server_context_factory = context_factory.ServerContextFactory(config) + hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + config + ) + + # It is now safe to start your Synapse. + hs.start_listening() + hs.get_pusherpool().start() + hs.get_datastore().start_profiling() + hs.get_datastore().start_doing_background_updates() + except Exception as e: + # If a DeferredList failed (like in listening on the ACME listener), + # we need to print the subfailure explicitly. + if isinstance(e, defer.FirstError): + e.subFailure.printTraceback(sys.stderr) + sys.exit(1) + + # Something else went wrong when starting. Print it and bail out. + traceback.print_exc(file=sys.stderr) + sys.exit(1) reactor.callWhenRunning(start) diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index fd2d6d52ef..5858fb92b4 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py
@@ -367,7 +367,7 @@ class Config(object): if not keys_directory: keys_directory = os.path.dirname(config_files[-1]) - config_dir_path = os.path.abspath(keys_directory) + self.config_dir_path = os.path.abspath(keys_directory) specified_config = {} for config_file in config_files: @@ -379,7 +379,7 @@ class Config(object): server_name = specified_config["server_name"] config_string = self.generate_config( - config_dir_path=config_dir_path, + config_dir_path=self.config_dir_path, data_dir_path=os.getcwd(), server_name=server_name, generate_secrets=False, diff --git a/synapse/config/key.py b/synapse/config/key.py
index 3b11f0cfa9..dce4b19a2d 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py
@@ -83,9 +83,6 @@ class KeyConfig(Config): # a secret which is used to sign access tokens. If none is specified, # the registration_shared_secret is used, if one is given; otherwise, # a secret key is derived from the signing key. - # - # Note that changing this will invalidate any active access tokens, so - # all clients will have to log back in. %(macaroon_secret_key)s # Used to enable access token expiration. diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 6c2b543b8c..fe520d6855 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py
@@ -50,6 +50,10 @@ class RegistrationConfig(Config): raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.disable_msisdn_registration = ( + config.get("disable_msisdn_registration", False) + ) + def default_config(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( @@ -70,6 +74,11 @@ class RegistrationConfig(Config): # - email # - msisdn + # Explicitly disable asking for MSISDNs from the registration + # flow (overrides registrations_require_3pid if MSISDNs are set as required) + # + # disable_msisdn_registration = True + # Mandate that users are only allowed to associate certain formats of # 3PIDs with accounts on this server. # diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index fef1ea99cb..a75e233aa0 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py
@@ -13,68 +13,110 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os -import subprocess +from datetime import datetime from hashlib import sha256 from unpaddedbase64 import encode_base64 from OpenSSL import crypto -from ._base import Config +from synapse.config._base import Config -GENERATE_DH_PARAMS = False +logger = logging.getLogger() class TlsConfig(Config): def read_config(self, config): - self.tls_certificate = self.read_tls_certificate( - config.get("tls_certificate_path") - ) - self.tls_certificate_file = config.get("tls_certificate_path") + acme_config = config.get("acme", {}) + self.acme_enabled = acme_config.get("enabled", False) + self.acme_url = acme_config.get( + "url", "https://acme-v01.api.letsencrypt.org/directory" + ) + self.acme_port = acme_config.get("port", 8449) + self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"]) + self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30) + + self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path")) + self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path")) + self._original_tls_fingerprints = config["tls_fingerprints"] + self.tls_fingerprints = list(self._original_tls_fingerprints) self.no_tls = config.get("no_tls", False) - if self.no_tls: - self.tls_private_key = None - else: - self.tls_private_key = self.read_tls_private_key( - config.get("tls_private_key_path") - ) + # This config option applies to non-federation HTTP clients + # (e.g. for talking to recaptcha, identity servers, and such) + # It should never be used in production, and is intended for + # use only when running tests. + self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get( + "use_insecure_ssl_client_just_for_testing_do_not_use" + ) - self.tls_dh_params_path = self.check_file( - config.get("tls_dh_params_path"), "tls_dh_params" + self.tls_certificate = None + self.tls_private_key = None + + def is_disk_cert_valid(self): + """ + Is the certificate we have on disk valid, and if so, for how long? + + Returns: + int: Days remaining of certificate validity. + None: No certificate exists. + """ + if not os.path.exists(self.tls_certificate_file): + return None + + try: + with open(self.tls_certificate_file, 'rb') as f: + cert_pem = f.read() + except Exception: + logger.exception("Failed to read existing certificate off disk!") + raise + + try: + tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) + except Exception: + logger.exception("Failed to parse existing certificate off disk!") + raise + + # YYYYMMDDhhmmssZ -- in UTC + expires_on = datetime.strptime( + tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ" ) + now = datetime.utcnow() + days_remaining = (expires_on - now).days + return days_remaining + + def read_certificate_from_disk(self): + """ + Read the certificates from disk. + """ + self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file) + + if not self.no_tls: + self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file) - self.tls_fingerprints = config["tls_fingerprints"] + self.tls_fingerprints = list(self._original_tls_fingerprints) # Check that our own certificate is included in the list of fingerprints # and include it if it is not. x509_certificate_bytes = crypto.dump_certificate( - crypto.FILETYPE_ASN1, - self.tls_certificate + crypto.FILETYPE_ASN1, self.tls_certificate ) sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest()) sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints) if sha256_fingerprint not in sha256_fingerprints: self.tls_fingerprints.append({u"sha256": sha256_fingerprint}) - # This config option applies to non-federation HTTP clients - # (e.g. for talking to recaptcha, identity servers, and such) - # It should never be used in production, and is intended for - # use only when running tests. - self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get( - "use_insecure_ssl_client_just_for_testing_do_not_use" - ) - def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) tls_certificate_path = base_key_name + ".tls.crt" tls_private_key_path = base_key_name + ".tls.key" - tls_dh_params_path = base_key_name + ".tls.dh" - return """\ + return ( + """\ # PEM encoded X509 certificate for TLS. # You can replace the self-signed certificate that synapse # autogenerates on launch with your own SSL certificate + key pair @@ -85,9 +127,6 @@ class TlsConfig(Config): # PEM encoded private key for TLS tls_private_key_path: "%(tls_private_key_path)s" - # PEM dh parameters for ephemeral keys - tls_dh_params_path: "%(tls_dh_params_path)s" - # Don't bind to the https port no_tls: False @@ -118,7 +157,24 @@ class TlsConfig(Config): # tls_fingerprints: [] # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}] - """ % locals() + + ## Support for ACME certificate auto-provisioning. + # acme: + # enabled: false + ## ACME path. + ## If you only want to test, use the staging url: + ## https://acme-staging.api.letsencrypt.org/directory + # url: 'https://acme-v01.api.letsencrypt.org/directory' + ## Port number (to listen for the HTTP-01 challenge). + ## Using port 80 requires utilising something like authbind, or proxying to it. + # port: 8449 + ## Hosts to bind to. + # bind_addresses: ['127.0.0.1'] + ## How many days remaining on a certificate before it is renewed. + # reprovision_threshold: 30 + """ + % locals() + ) def read_tls_certificate(self, cert_path): cert_pem = self.read_file(cert_path, "tls_certificate") @@ -131,7 +187,6 @@ class TlsConfig(Config): def generate_files(self, config): tls_certificate_path = config["tls_certificate_path"] tls_private_key_path = config["tls_private_key_path"] - tls_dh_params_path = config["tls_dh_params_path"] if not self.path_exists(tls_private_key_path): with open(tls_private_key_path, "wb") as private_key_file: @@ -165,31 +220,3 @@ class TlsConfig(Config): cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) certificate_file.write(cert_pem) - - if not self.path_exists(tls_dh_params_path): - if GENERATE_DH_PARAMS: - subprocess.check_call([ - "openssl", "dhparam", - "-outform", "PEM", - "-out", tls_dh_params_path, - "2048" - ]) - else: - with open(tls_dh_params_path, "w") as dh_params_file: - dh_params_file.write( - "2048-bit DH parameters taken from rfc3526\n" - "-----BEGIN DH PARAMETERS-----\n" - "MIIBCAKCAQEA///////////JD9qiIWjC" - "NMTGYouA3BzRKQJOCIpnzHQCC76mOxOb\n" - "IlFKCHmONATd75UZs806QxswKwpt8l8U" - "N0/hNW1tUcJF5IW1dmJefsb0TELppjft\n" - "awv/XLb0Brft7jhr+1qJn6WunyQRfEsf" - "5kkoZlHs5Fs9wgB8uKFjvwWY2kg2HFXT\n" - "mmkWP6j9JM9fg2VdI9yjrZYcYvNWIIVS" - "u57VKQdwlpZtZww1Tkq8mATxdGwIyhgh\n" - "fDKQXkYuNs474553LBgOhgObJ4Oi7Aei" - "j7XFXfBvTFLJ3ivL9pVYFxg5lUl86pVq\n" - "5RXSJhiY+gUQFXKOWoqsqmj/////////" - "/wIBAg==\n" - "-----END DH PARAMETERS-----\n" - ) diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 02b76dfcfb..286ad80100 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py
@@ -17,6 +17,7 @@ from zope.interface import implementer from OpenSSL import SSL, crypto from twisted.internet._sslverify import _defaultCurveName +from twisted.internet.abstract import isIPAddress, isIPv6Address from twisted.internet.interfaces import IOpenSSLClientConnectionCreator from twisted.internet.ssl import CertificateOptions, ContextFactory from twisted.python.failure import Failure @@ -46,8 +47,10 @@ class ServerContextFactory(ContextFactory): if not config.no_tls: context.use_privatekey(config.tls_private_key) - context.load_tmp_dh(config.tls_dh_params_path) - context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH") + # https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ + context.set_cipher_list( + "ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1" + ) def getContext(self): return self._context @@ -96,8 +99,14 @@ class ClientTLSOptions(object): def __init__(self, hostname, ctx): self._ctx = ctx - self._hostname = hostname - self._hostnameBytes = _idnaBytes(hostname) + + if isIPAddress(hostname) or isIPv6Address(hostname): + self._hostnameBytes = hostname.encode('ascii') + self._sendSNI = False + else: + self._hostnameBytes = _idnaBytes(hostname) + self._sendSNI = True + ctx.set_info_callback( _tolerateErrors(self._identityVerifyingInfoCallback) ) @@ -109,7 +118,9 @@ class ClientTLSOptions(object): return connection def _identityVerifyingInfoCallback(self, connection, where, ret): - if where & SSL.SSL_CB_HANDSHAKE_START: + # Literal IPv4 and IPv6 addresses are not permitted + # as host names according to the RFCs + if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI: connection.set_tlsext_host_name(self._hostnameBytes) diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py deleted file mode 100644
index d40e4b8591..0000000000 --- a/synapse/crypto/keyclient.py +++ /dev/null
@@ -1,149 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from six.moves import urllib - -from canonicaljson import json - -from twisted.internet import defer, reactor -from twisted.internet.error import ConnectError -from twisted.internet.protocol import Factory -from twisted.names.error import DomainError -from twisted.web.http import HTTPClient - -from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util import logcontext - -logger = logging.getLogger(__name__) - -KEY_API_V2 = "/_matrix/key/v2/server/%s" - - -@defer.inlineCallbacks -def fetch_server_key(server_name, tls_client_options_factory, key_id): - """Fetch the keys for a remote server.""" - - factory = SynapseKeyClientFactory() - factory.path = KEY_API_V2 % (urllib.parse.quote(key_id), ) - factory.host = server_name - endpoint = matrix_federation_endpoint( - reactor, server_name, tls_client_options_factory, timeout=30 - ) - - for i in range(5): - try: - with logcontext.PreserveLoggingContext(): - protocol = yield endpoint.connect(factory) - server_response, server_certificate = yield protocol.remote_key - defer.returnValue((server_response, server_certificate)) - except SynapseKeyClientError as e: - logger.warn("Error getting key for %r: %s", server_name, e) - if e.status.startswith(b"4"): - # Don't retry for 4xx responses. - raise IOError("Cannot get key for %r" % server_name) - except (ConnectError, DomainError) as e: - logger.warn("Error getting key for %r: %s", server_name, e) - except Exception: - logger.exception("Error getting key for %r", server_name) - raise IOError("Cannot get key for %r" % server_name) - - -class SynapseKeyClientError(Exception): - """The key wasn't retrieved from the remote server.""" - status = None - pass - - -class SynapseKeyClientProtocol(HTTPClient): - """Low level HTTPS client which retrieves an application/json response from - the server and extracts the X.509 certificate for the remote peer from the - SSL connection.""" - - timeout = 30 - - def __init__(self): - self.remote_key = defer.Deferred() - self.host = None - self._peer = None - - def connectionMade(self): - self._peer = self.transport.getPeer() - logger.debug("Connected to %s", self._peer) - - if not isinstance(self.path, bytes): - self.path = self.path.encode('ascii') - - if not isinstance(self.host, bytes): - self.host = self.host.encode('ascii') - - self.sendCommand(b"GET", self.path) - if self.host: - self.sendHeader(b"Host", self.host) - self.endHeaders() - self.timer = reactor.callLater( - self.timeout, - self.on_timeout - ) - - def errback(self, error): - if not self.remote_key.called: - self.remote_key.errback(error) - - def callback(self, result): - if not self.remote_key.called: - self.remote_key.callback(result) - - def handleStatus(self, version, status, message): - if status != b"200": - # logger.info("Non-200 response from %s: %s %s", - # self.transport.getHost(), status, message) - error = SynapseKeyClientError( - "Non-200 response %r from %r" % (status, self.host) - ) - error.status = status - self.errback(error) - self.transport.abortConnection() - - def handleResponse(self, response_body_bytes): - try: - json_response = json.loads(response_body_bytes) - except ValueError: - # logger.info("Invalid JSON response from %s", - # self.transport.getHost()) - self.transport.abortConnection() - return - - certificate = self.transport.getPeerCertificate() - self.callback((json_response, certificate)) - self.transport.abortConnection() - self.timer.cancel() - - def on_timeout(self): - logger.debug( - "Timeout waiting for response from %s: %s", - self.host, self._peer, - ) - self.errback(IOError("Timeout waiting for response")) - self.transport.abortConnection() - - -class SynapseKeyClientFactory(Factory): - def protocol(self): - protocol = SynapseKeyClientProtocol() - protocol.path = self.path - protocol.host = self.host - return protocol diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 515ebbc148..3a96980bed 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py
@@ -14,10 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib import logging from collections import namedtuple +from six.moves import urllib + from signedjson.key import ( decode_verify_key_bytes, encode_verify_key_base64, @@ -30,13 +31,11 @@ from signedjson.sign import ( signature_ids, verify_signed_json, ) -from unpaddedbase64 import decode_base64, encode_base64 +from unpaddedbase64 import decode_base64 -from OpenSSL import crypto from twisted.internet import defer from synapse.api.errors import Codes, SynapseError -from synapse.crypto.keyclient import fetch_server_key from synapse.util import logcontext, unwrapFirstError from synapse.util.logcontext import ( LoggingContext, @@ -503,31 +502,16 @@ class Keyring(object): if requested_key_id in keys: continue - (response, tls_certificate) = yield fetch_server_key( - server_name, self.hs.tls_client_options_factory, requested_key_id + response = yield self.client.get_json( + destination=server_name, + path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id), + ignore_backoff=True, ) if (u"signatures" not in response or server_name not in response[u"signatures"]): raise KeyLookupError("Key response not signed by remote server") - if "tls_fingerprints" not in response: - raise KeyLookupError("Key response missing TLS fingerprints") - - certificate_bytes = crypto.dump_certificate( - crypto.FILETYPE_ASN1, tls_certificate - ) - sha256_fingerprint = hashlib.sha256(certificate_bytes).digest() - sha256_fingerprint_b64 = encode_base64(sha256_fingerprint) - - response_sha256_fingerprints = set() - for fingerprint in response[u"tls_fingerprints"]: - if u"sha256" in fingerprint: - response_sha256_fingerprints.add(fingerprint[u"sha256"]) - - if sha256_fingerprint_b64 not in response_sha256_fingerprints: - raise KeyLookupError("TLS certificate not allowed by fingerprints") - response_keys = yield self.process_v2_response( from_server=server_name, requested_ids=[requested_key_id], diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 98722ae543..37d29e7027 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -369,13 +369,13 @@ class FederationServer(FederationBase): }) @defer.inlineCallbacks - def on_invite_request(self, origin, content): + def on_invite_request(self, origin, content, room_version): pdu = event_from_pdu_json(content) origin_host, _ = parse_server_name(origin) yield self.check_server_matches_acl(origin_host, pdu.room_id) ret_pdu = yield self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() - defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) + defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)}) @defer.inlineCallbacks def on_send_join_request(self, origin, content): diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index edba5a9808..260178c47b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -21,7 +21,7 @@ from six.moves import urllib from twisted.internet import defer from synapse.api.constants import Membership -from synapse.api.urls import FEDERATION_PREFIX as PREFIX +from synapse.api.urls import FEDERATION_V1_PREFIX from synapse.util.logutils import log_function logger = logging.getLogger(__name__) @@ -51,7 +51,7 @@ class TransportLayerClient(object): logger.debug("get_room_state dest=%s, room=%s", destination, room_id) - path = _create_path(PREFIX, "/state/%s/", room_id) + path = _create_v1_path("/state/%s/", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, ) @@ -73,7 +73,7 @@ class TransportLayerClient(object): logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) - path = _create_path(PREFIX, "/state_ids/%s/", room_id) + path = _create_v1_path("/state_ids/%s/", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, ) @@ -95,7 +95,7 @@ class TransportLayerClient(object): logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) - path = _create_path(PREFIX, "/event/%s/", event_id) + path = _create_v1_path("/event/%s/", event_id) return self.client.get_json(destination, path=path, timeout=timeout) @log_function @@ -121,7 +121,7 @@ class TransportLayerClient(object): # TODO: raise? return - path = _create_path(PREFIX, "/backfill/%s/", room_id) + path = _create_v1_path("/backfill/%s/", room_id) args = { "v": event_tuples, @@ -167,7 +167,7 @@ class TransportLayerClient(object): # generated by the json_data_callback. json_data = transaction.get_dict() - path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id) + path = _create_v1_path("/send/%s/", transaction.transaction_id) response = yield self.client.put_json( transaction.destination, @@ -184,7 +184,7 @@ class TransportLayerClient(object): @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False): - path = _create_path(PREFIX, "/query/%s", query_type) + path = _create_v1_path("/query/%s", query_type) content = yield self.client.get_json( destination=destination, @@ -231,7 +231,7 @@ class TransportLayerClient(object): "make_membership_event called with membership='%s', must be one of %s" % (membership, ",".join(valid_memberships)) ) - path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id) + path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id) ignore_backoff = False retry_on_dns_fail = False @@ -258,7 +258,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_join(self, destination, room_id, event_id, content): - path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id) + path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -271,7 +271,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_leave(self, destination, room_id, event_id, content): - path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id) + path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -290,7 +290,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_invite(self, destination, room_id, event_id, content): - path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id) + path = _create_v1_path("/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -306,7 +306,7 @@ class TransportLayerClient(object): def get_public_rooms(self, remote_server, limit, since_token, search_filter=None, include_all_networks=False, third_party_instance_id=None): - path = PREFIX + "/publicRooms" + path = _create_v1_path("/publicRooms") args = { "include_all_networks": "true" if include_all_networks else "false", @@ -332,7 +332,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): - path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,) + path = _create_v1_path("/exchange_third_party_invite/%s", room_id,) response = yield self.client.put_json( destination=destination, @@ -345,7 +345,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): - path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id) + path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) content = yield self.client.get_json( destination=destination, @@ -357,7 +357,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_query_auth(self, destination, room_id, event_id, content): - path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id) + path = _create_v1_path("/query_auth/%s/%s", room_id, event_id) content = yield self.client.post_json( destination=destination, @@ -392,7 +392,7 @@ class TransportLayerClient(object): Returns: A dict containg the device keys. """ - path = PREFIX + "/user/keys/query" + path = _create_v1_path("/user/keys/query") content = yield self.client.post_json( destination=destination, @@ -419,7 +419,7 @@ class TransportLayerClient(object): Returns: A dict containg the device keys. """ - path = _create_path(PREFIX, "/user/devices/%s", user_id) + path = _create_v1_path("/user/devices/%s", user_id) content = yield self.client.get_json( destination=destination, @@ -455,7 +455,7 @@ class TransportLayerClient(object): A dict containg the one-time keys. """ - path = PREFIX + "/user/keys/claim" + path = _create_v1_path("/user/keys/claim") content = yield self.client.post_json( destination=destination, @@ -469,7 +469,7 @@ class TransportLayerClient(object): @log_function def get_missing_events(self, destination, room_id, earliest_events, latest_events, limit, min_depth, timeout): - path = _create_path(PREFIX, "/get_missing_events/%s", room_id,) + path = _create_v1_path("/get_missing_events/%s", room_id,) content = yield self.client.post_json( destination=destination, @@ -489,7 +489,7 @@ class TransportLayerClient(object): def get_group_profile(self, destination, group_id, requester_user_id): """Get a group profile """ - path = _create_path(PREFIX, "/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id,) return self.client.get_json( destination=destination, @@ -508,7 +508,7 @@ class TransportLayerClient(object): requester_user_id (str) content (dict): The new profile of the group """ - path = _create_path(PREFIX, "/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id,) return self.client.post_json( destination=destination, @@ -522,7 +522,7 @@ class TransportLayerClient(object): def get_group_summary(self, destination, group_id, requester_user_id): """Get a group summary """ - path = _create_path(PREFIX, "/groups/%s/summary", group_id,) + path = _create_v1_path("/groups/%s/summary", group_id,) return self.client.get_json( destination=destination, @@ -535,7 +535,7 @@ class TransportLayerClient(object): def get_rooms_in_group(self, destination, group_id, requester_user_id): """Get all rooms in a group """ - path = _create_path(PREFIX, "/groups/%s/rooms", group_id,) + path = _create_v1_path("/groups/%s/rooms", group_id,) return self.client.get_json( destination=destination, @@ -548,7 +548,7 @@ class TransportLayerClient(object): content): """Add a room to a group """ - path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) return self.client.post_json( destination=destination, @@ -562,8 +562,8 @@ class TransportLayerClient(object): config_key, content): """Update room in group """ - path = _create_path( - PREFIX, "/groups/%s/room/%s/config/%s", + path = _create_v1_path( + "/groups/%s/room/%s/config/%s", group_id, room_id, config_key, ) @@ -578,7 +578,7 @@ class TransportLayerClient(object): def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): """Remove a room from a group """ - path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) return self.client.delete_json( destination=destination, @@ -591,7 +591,7 @@ class TransportLayerClient(object): def get_users_in_group(self, destination, group_id, requester_user_id): """Get users in a group """ - path = _create_path(PREFIX, "/groups/%s/users", group_id,) + path = _create_v1_path("/groups/%s/users", group_id,) return self.client.get_json( destination=destination, @@ -604,7 +604,7 @@ class TransportLayerClient(object): def get_invited_users_in_group(self, destination, group_id, requester_user_id): """Get users that have been invited to a group """ - path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,) + path = _create_v1_path("/groups/%s/invited_users", group_id,) return self.client.get_json( destination=destination, @@ -617,8 +617,8 @@ class TransportLayerClient(object): def accept_group_invite(self, destination, group_id, user_id, content): """Accept a group invite """ - path = _create_path( - PREFIX, "/groups/%s/users/%s/accept_invite", + path = _create_v1_path( + "/groups/%s/users/%s/accept_invite", group_id, user_id, ) @@ -633,7 +633,7 @@ class TransportLayerClient(object): def join_group(self, destination, group_id, user_id, content): """Attempts to join a group """ - path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id) + path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) return self.client.post_json( destination=destination, @@ -646,7 +646,7 @@ class TransportLayerClient(object): def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): """Invite a user to a group """ - path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id) + path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) return self.client.post_json( destination=destination, @@ -662,7 +662,7 @@ class TransportLayerClient(object): invited. """ - path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id) + path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id) return self.client.post_json( destination=destination, @@ -676,7 +676,7 @@ class TransportLayerClient(object): user_id, content): """Remove a user fron a group """ - path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id) + path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) return self.client.post_json( destination=destination, @@ -693,7 +693,7 @@ class TransportLayerClient(object): kicked from the group. """ - path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id) + path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id) return self.client.post_json( destination=destination, @@ -708,7 +708,7 @@ class TransportLayerClient(object): the attestations """ - path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id) + path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id) return self.client.post_json( destination=destination, @@ -723,12 +723,12 @@ class TransportLayerClient(object): """Update a room entry in a group summary """ if category_id: - path = _create_path( - PREFIX, "/groups/%s/summary/categories/%s/rooms/%s", + path = _create_v1_path( + "/groups/%s/summary/categories/%s/rooms/%s", group_id, category_id, room_id, ) else: - path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) return self.client.post_json( destination=destination, @@ -744,12 +744,12 @@ class TransportLayerClient(object): """Delete a room entry in a group summary """ if category_id: - path = _create_path( - PREFIX + "/groups/%s/summary/categories/%s/rooms/%s", + path = _create_v1_path( + "/groups/%s/summary/categories/%s/rooms/%s", group_id, category_id, room_id, ) else: - path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) return self.client.delete_json( destination=destination, @@ -762,7 +762,7 @@ class TransportLayerClient(object): def get_group_categories(self, destination, group_id, requester_user_id): """Get all categories in a group """ - path = _create_path(PREFIX, "/groups/%s/categories", group_id,) + path = _create_v1_path("/groups/%s/categories", group_id,) return self.client.get_json( destination=destination, @@ -775,7 +775,7 @@ class TransportLayerClient(object): def get_group_category(self, destination, group_id, requester_user_id, category_id): """Get category info in a group """ - path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) return self.client.get_json( destination=destination, @@ -789,7 +789,7 @@ class TransportLayerClient(object): content): """Update a category in a group """ - path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) return self.client.post_json( destination=destination, @@ -804,7 +804,7 @@ class TransportLayerClient(object): category_id): """Delete a category in a group """ - path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) return self.client.delete_json( destination=destination, @@ -817,7 +817,7 @@ class TransportLayerClient(object): def get_group_roles(self, destination, group_id, requester_user_id): """Get all roles in a group """ - path = _create_path(PREFIX, "/groups/%s/roles", group_id,) + path = _create_v1_path("/groups/%s/roles", group_id,) return self.client.get_json( destination=destination, @@ -830,7 +830,7 @@ class TransportLayerClient(object): def get_group_role(self, destination, group_id, requester_user_id, role_id): """Get a roles info """ - path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) return self.client.get_json( destination=destination, @@ -844,7 +844,7 @@ class TransportLayerClient(object): content): """Update a role in a group """ - path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) return self.client.post_json( destination=destination, @@ -858,7 +858,7 @@ class TransportLayerClient(object): def delete_group_role(self, destination, group_id, requester_user_id, role_id): """Delete a role in a group """ - path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) return self.client.delete_json( destination=destination, @@ -873,12 +873,12 @@ class TransportLayerClient(object): """Update a users entry in a group """ if role_id: - path = _create_path( - PREFIX, "/groups/%s/summary/roles/%s/users/%s", + path = _create_v1_path( + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id, ) else: - path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) return self.client.post_json( destination=destination, @@ -893,7 +893,7 @@ class TransportLayerClient(object): content): """Sets the join policy for a group """ - path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,) + path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,) return self.client.put_json( destination=destination, @@ -909,12 +909,12 @@ class TransportLayerClient(object): """Delete a users entry in a group """ if role_id: - path = _create_path( - PREFIX, "/groups/%s/summary/roles/%s/users/%s", + path = _create_v1_path( + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id, ) else: - path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) return self.client.delete_json( destination=destination, @@ -927,7 +927,7 @@ class TransportLayerClient(object): """Get the groups a list of users are publicising """ - path = PREFIX + "/get_groups_publicised" + path = _create_v1_path("/get_groups_publicised") content = {"user_ids": user_ids} @@ -939,20 +939,22 @@ class TransportLayerClient(object): ) -def _create_path(prefix, path, *args): - """Creates a path from the prefix, path template and args. Ensures that - all args are url encoded. +def _create_v1_path(path, *args): + """Creates a path against V1 federation API from the path template and + args. Ensures that all args are url encoded. Example: - _create_path(PREFIX, "/event/%s/", event_id) + _create_v1_path("/event/%s/", event_id) Args: - prefix (str) path (str): String template for the path args: ([str]): Args to insert into path. Each arg will be url encoded Returns: str """ - return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args) + return ( + FEDERATION_V1_PREFIX + + path % tuple(urllib.parse.quote(arg, "") for arg in args) + ) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 3553f418f1..4557a9e66e 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -21,8 +21,9 @@ import re from twisted.internet import defer import synapse +from synapse.api.constants import RoomVersions from synapse.api.errors import Codes, FederationDeniedError, SynapseError -from synapse.api.urls import FEDERATION_PREFIX as PREFIX +from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( @@ -227,6 +228,8 @@ class BaseFederationServlet(object): """ REQUIRE_AUTH = True + PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version + def __init__(self, handler, authenticator, ratelimiter, server_name): self.handler = handler self.authenticator = authenticator @@ -286,7 +289,7 @@ class BaseFederationServlet(object): return new_func def register(self, server): - pattern = re.compile("^" + PREFIX + self.PATH + "$") + pattern = re.compile("^" + self.PREFIX + self.PATH + "$") for method in ("GET", "PUT", "POST"): code = getattr(self, "on_%s" % (method), None) @@ -488,14 +491,46 @@ class FederationSendJoinServlet(BaseFederationServlet): defer.returnValue((200, content)) -class FederationInviteServlet(BaseFederationServlet): +class FederationV1InviteServlet(BaseFederationServlet): PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" @defer.inlineCallbacks def on_PUT(self, origin, content, query, context, event_id): + # We don't get a room version, so we have to assume its EITHER v1 or + # v2. This is "fine" as the only difference between V1 and V2 is the + # state resolution algorithm, and we don't use that for processing + # invites + content = yield self.handler.on_invite_request( + origin, content, room_version=RoomVersions.V1, + ) + + # V1 federation API is defined to return a content of `[200, {...}]` + # due to a historical bug. + defer.returnValue((200, (200, content))) + + +class FederationV2InviteServlet(BaseFederationServlet): + PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + @defer.inlineCallbacks + def on_PUT(self, origin, content, query, context, event_id): # TODO(paul): assert that context/event_id parsed from path actually # match those given in content - content = yield self.handler.on_invite_request(origin, content) + + room_version = content["room_version"] + event = content["event"] + invite_room_state = content["invite_room_state"] + + # Synapse expects invite_room_state to be in unsigned, as it is in v1 + # API + + event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state + + content = yield self.handler.on_invite_request( + origin, event, room_version=room_version, + ) defer.returnValue((200, content)) @@ -1263,7 +1298,8 @@ FEDERATION_SERVLET_CLASSES = ( FederationEventServlet, FederationSendJoinServlet, FederationSendLeaveServlet, - FederationInviteServlet, + FederationV1InviteServlet, + FederationV2InviteServlet, FederationQueryAuthServlet, FederationGetMissingEventsServlet, FederationEventAuthServlet, diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py new file mode 100644
index 0000000000..73ea7ed018 --- /dev/null +++ b/synapse/handlers/acme.py
@@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import attr +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.endpoints import serverFromString +from twisted.python.filepath import FilePath +from twisted.python.url import URL +from twisted.web import server, static +from twisted.web.resource import Resource + +logger = logging.getLogger(__name__) + +try: + from txacme.interfaces import ICertificateStore + + @attr.s + @implementer(ICertificateStore) + class ErsatzStore(object): + """ + A store that only stores in memory. + """ + + certs = attr.ib(default=attr.Factory(dict)) + + def store(self, server_name, pem_objects): + self.certs[server_name] = [o.as_bytes() for o in pem_objects] + return defer.succeed(None) + + +except ImportError: + # txacme is missing + pass + + +class AcmeHandler(object): + def __init__(self, hs): + self.hs = hs + self.reactor = hs.get_reactor() + + @defer.inlineCallbacks + def start_listening(self): + + # Configure logging for txacme, if you need to debug + # from eliot import add_destinations + # from eliot.twisted import TwistedDestination + # + # add_destinations(TwistedDestination()) + + from txacme.challenges import HTTP01Responder + from txacme.service import AcmeIssuingService + from txacme.endpoint import load_or_create_client_key + from txacme.client import Client + from josepy.jwa import RS256 + + self._store = ErsatzStore() + responder = HTTP01Responder() + + self._issuer = AcmeIssuingService( + cert_store=self._store, + client_creator=( + lambda: Client.from_url( + reactor=self.reactor, + url=URL.from_text(self.hs.config.acme_url), + key=load_or_create_client_key( + FilePath(self.hs.config.config_dir_path) + ), + alg=RS256, + ) + ), + clock=self.reactor, + responders=[responder], + ) + + well_known = Resource() + well_known.putChild(b'acme-challenge', responder.resource) + responder_resource = Resource() + responder_resource.putChild(b'.well-known', well_known) + responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain')) + + srv = server.Site(responder_resource) + + listeners = [] + + for host in self.hs.config.acme_bind_addresses: + logger.info( + "Listening for ACME requests on %s:%s", host, self.hs.config.acme_port + ) + endpoint = serverFromString( + self.reactor, "tcp:%s:interface=%s" % (self.hs.config.acme_port, host) + ) + listeners.append(endpoint.listen(srv)) + + # Make sure we are registered to the ACME server. There's no public API + # for this, it is usually triggered by startService, but since we don't + # want it to control where we save the certificates, we have to reach in + # and trigger the registration machinery ourselves. + self._issuer._registered = False + yield self._issuer._ensure_registered() + + # Return a Deferred that will fire when all the servers have started up. + yield defer.DeferredList(listeners, fireOnOneErrback=True, consumeErrors=True) + + @defer.inlineCallbacks + def provision_certificate(self): + + logger.warning("Reprovisioning %s", self.hs.hostname) + + try: + yield self._issuer.issue_cert(self.hs.hostname) + except Exception: + logger.exception("Fail!") + raise + logger.warning("Reprovisioned %s, saving.", self.hs.hostname) + cert_chain = self._store.certs[self.hs.hostname] + + try: + with open(self.hs.config.tls_private_key_file, "wb") as private_key_file: + for x in cert_chain: + if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"): + private_key_file.write(x) + + with open(self.hs.config.tls_certificate_file, "wb") as certificate_file: + for x in cert_chain: + if x.startswith(b"-----BEGIN CERTIFICATE-----"): + certificate_file.write(x) + except Exception: + logger.exception("Failed saving!") + raise + + defer.returnValue(True) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 5feb3f22a6..39184f0e22 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py
@@ -167,18 +167,21 @@ class IdentityHandler(BaseHandler): "mxid": mxid, "threepid": threepid, } - headers = {} + # we abuse the federation http client to sign the request, but we have to send it # using the normal http client since we don't want the SRV lookup and want normal # 'browser-like' HTTPS. - self.federation_http_client.sign_request( + auth_headers = self.federation_http_client.build_auth_headers( destination=None, method='POST', url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'), - headers_dict=headers, content=content, destination_is=id_server, ) + headers = { + b"Authorization": auth_headers, + } + try: yield self.http_client.post_json_get_json( url, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 581e96c743..cb8c5f77dd 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -269,6 +269,7 @@ class RoomCreationHandler(BaseHandler): (EventTypes.RoomHistoryVisibility, ""), (EventTypes.GuestAccess, ""), (EventTypes.RoomAvatar, ""), + (EventTypes.Encryption, ""), ) old_room_state_ids = yield self.store.get_filtered_current_state_ids( diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index b9c539b985..5aa7987126 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py
@@ -15,6 +15,7 @@ import logging from collections import namedtuple +from datetime import datetime, timedelta from six import PY3, iteritems from six.moves import range @@ -76,8 +77,14 @@ class RoomListHandler(BaseHandler): # We explicitly don't bother caching searches or requests for # appservice specific lists. logger.info("Bypassing cache as search request.") + + # XXX: Quick hack to stop room directory queries taking too long. + # Timeout request after 60s. Probably want a more fundamental + # solution at some point + timeout = datetime.now() + timedelta(seconds=60) return self._get_public_room_list( - limit, since_token, search_filter, network_tuple=network_tuple, + limit, since_token, search_filter, + network_tuple=network_tuple, timeout=timeout, ) key = (limit, since_token, network_tuple) @@ -90,7 +97,8 @@ class RoomListHandler(BaseHandler): @defer.inlineCallbacks def _get_public_room_list(self, limit=None, since_token=None, search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID,): + network_tuple=EMPTY_THIRD_PARTY_ID, + timeout=None,): if since_token and since_token != "END": since_token = RoomListNextBatch.from_token(since_token) else: @@ -205,6 +213,9 @@ class RoomListHandler(BaseHandler): chunk = [] for i in range(0, len(rooms_to_scan), step): + if timeout and datetime.now() > timeout: + raise Exception("Timed out searching room directory") + batch = rooms_to_scan[i:i + step] logger.info("Processing %i rooms for result", len(batch)) yield concurrently_execute( diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index a9f062df4b..71ea2e3cee 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py
@@ -20,6 +20,7 @@ from six import iteritems from twisted.internet import defer +import synapse.metrics from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index f86a0b624e..cd79ebab62 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py
@@ -12,30 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging -import random import re -import time - -from twisted.internet import defer -from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.internet.error import ConnectError -from twisted.names import client, dns -from twisted.names.error import DNSNameError, DomainError logger = logging.getLogger(__name__) -SERVER_CACHE = {} - -# our record of an individual server which can be tried to reach a destination. -# -# "host" is the hostname acquired from the SRV record. Except when there's -# no SRV record, in which case it is the original hostname. -_Server = collections.namedtuple( - "_Server", "priority weight host port expires" -) - def parse_server_name(server_name): """Split a server name into host/port parts. @@ -100,264 +81,3 @@ def parse_and_validate_server_name(server_name): )) return host, port - - -def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None, - timeout=None): - """Construct an endpoint for the given matrix destination. - - Args: - reactor: Twisted reactor. - destination (unicode): The name of the server to connect to. - tls_client_options_factory - (synapse.crypto.context_factory.ClientTLSOptionsFactory): - Factory which generates TLS options for client connections. - timeout (int): connection timeout in seconds - """ - - domain, port = parse_server_name(destination) - - endpoint_kw_args = {} - - if timeout is not None: - endpoint_kw_args.update(timeout=timeout) - - if tls_client_options_factory is None: - transport_endpoint = HostnameEndpoint - default_port = 8008 - else: - # the SNI string should be the same as the Host header, minus the port. - # as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777, - # the Host header and SNI should therefore be the server_name of the remote - # server. - tls_options = tls_client_options_factory.get_options(domain) - - def transport_endpoint(reactor, host, port, timeout): - return wrapClientTLS( - tls_options, - HostnameEndpoint(reactor, host, port, timeout=timeout), - ) - default_port = 8448 - - if port is None: - return _WrappingEndpointFac(SRVClientEndpoint( - reactor, "matrix", domain, protocol="tcp", - default_port=default_port, endpoint=transport_endpoint, - endpoint_kw_args=endpoint_kw_args - ), reactor) - else: - return _WrappingEndpointFac(transport_endpoint( - reactor, domain, port, **endpoint_kw_args - ), reactor) - - -class _WrappingEndpointFac(object): - def __init__(self, endpoint_fac, reactor): - self.endpoint_fac = endpoint_fac - self.reactor = reactor - - @defer.inlineCallbacks - def connect(self, protocolFactory): - conn = yield self.endpoint_fac.connect(protocolFactory) - conn = _WrappedConnection(conn, self.reactor) - defer.returnValue(conn) - - -class _WrappedConnection(object): - """Wraps a connection and calls abort on it if it hasn't seen any action - for 2.5-3 minutes. - """ - __slots__ = ["conn", "last_request"] - - def __init__(self, conn, reactor): - object.__setattr__(self, "conn", conn) - object.__setattr__(self, "last_request", time.time()) - self._reactor = reactor - - def __getattr__(self, name): - return getattr(self.conn, name) - - def __setattr__(self, name, value): - setattr(self.conn, name, value) - - def _time_things_out_maybe(self): - # We use a slightly shorter timeout here just in case the callLater is - # triggered early. Paranoia ftw. - # TODO: Cancel the previous callLater rather than comparing time.time()? - if time.time() - self.last_request >= 2.5 * 60: - self.abort() - # Abort the underlying TLS connection. The abort() method calls - # loseConnection() on the TLS connection which tries to - # shutdown the connection cleanly. We call abortConnection() - # since that will promptly close the TLS connection. - # - # In Twisted >18.4; the TLS connection will be None if it has closed - # which will make abortConnection() throw. Check that the TLS connection - # is not None before trying to close it. - if self.transport.getHandle() is not None: - self.transport.abortConnection() - - def request(self, request): - self.last_request = time.time() - - # Time this connection out if we haven't send a request in the last - # N minutes - # TODO: Cancel the previous callLater? - self._reactor.callLater(3 * 60, self._time_things_out_maybe) - - d = self.conn.request(request) - - def update_request_time(res): - self.last_request = time.time() - # TODO: Cancel the previous callLater? - self._reactor.callLater(3 * 60, self._time_things_out_maybe) - return res - - d.addCallback(update_request_time) - - return d - - -class SRVClientEndpoint(object): - """An endpoint which looks up SRV records for a service. - Cycles through the list of servers starting with each call to connect - picking the next server. - Implements twisted.internet.interfaces.IStreamClientEndpoint. - """ - - def __init__(self, reactor, service, domain, protocol="tcp", - default_port=None, endpoint=HostnameEndpoint, - endpoint_kw_args={}): - self.reactor = reactor - self.service_name = "_%s._%s.%s" % (service, protocol, domain) - - if default_port is not None: - self.default_server = _Server( - host=domain, - port=default_port, - priority=0, - weight=0, - expires=0, - ) - else: - self.default_server = None - - self.endpoint = endpoint - self.endpoint_kw_args = endpoint_kw_args - - self.servers = None - self.used_servers = None - - @defer.inlineCallbacks - def fetch_servers(self): - self.used_servers = [] - self.servers = yield resolve_service(self.service_name) - - def pick_server(self): - if not self.servers: - if self.used_servers: - self.servers = self.used_servers - self.used_servers = [] - self.servers.sort() - elif self.default_server: - return self.default_server - else: - raise ConnectError( - "No server available for %s" % self.service_name - ) - - # look for all servers with the same priority - min_priority = self.servers[0].priority - weight_indexes = list( - (index, server.weight + 1) - for index, server in enumerate(self.servers) - if server.priority == min_priority - ) - - total_weight = sum(weight for index, weight in weight_indexes) - target_weight = random.randint(0, total_weight) - for index, weight in weight_indexes: - target_weight -= weight - if target_weight <= 0: - server = self.servers[index] - # XXX: this looks totally dubious: - # - # (a) we never reuse a server until we have been through - # all of the servers at the same priority, so if the - # weights are A: 100, B:1, we always do ABABAB instead of - # AAAA...AAAB (approximately). - # - # (b) After using all the servers at the lowest priority, - # we move onto the next priority. We should only use the - # second priority if servers at the top priority are - # unreachable. - # - del self.servers[index] - self.used_servers.append(server) - return server - - @defer.inlineCallbacks - def connect(self, protocolFactory): - if self.servers is None: - yield self.fetch_servers() - server = self.pick_server() - logger.info("Connecting to %s:%s", server.host, server.port) - endpoint = self.endpoint( - self.reactor, server.host, server.port, **self.endpoint_kw_args - ) - connection = yield endpoint.connect(protocolFactory) - defer.returnValue(connection) - - -@defer.inlineCallbacks -def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): - cache_entry = cache.get(service_name, None) - if cache_entry: - if all(s.expires > int(clock.time()) for s in cache_entry): - servers = list(cache_entry) - defer.returnValue(servers) - - servers = [] - - try: - try: - answers, _, _ = yield dns_client.lookupService(service_name) - except DNSNameError: - defer.returnValue([]) - - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name(b'.')): - raise ConnectError("Service %s unavailable" % service_name) - - for answer in answers: - if answer.type != dns.SRV or not answer.payload: - continue - - payload = answer.payload - - servers.append(_Server( - host=str(payload.target), - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight), - expires=int(clock.time()) + answer.ttl, - )) - - servers.sort() - cache[service_name] = list(servers) - except DomainError as e: - # We failed to resolve the name (other than a NameError) - # Try something in the cache, else rereaise - cache_entry = cache.get(service_name, None) - if cache_entry: - logger.warn( - "Failed to resolve %r, falling back to cache. %r", - service_name, e - ) - servers = list(cache_entry) - else: - raise e - - defer.returnValue(servers) diff --git a/synapse/http/federation/__init__.py b/synapse/http/federation/__init__.py new file mode 100644
index 0000000000..1453d04571 --- /dev/null +++ b/synapse/http/federation/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py new file mode 100644
index 0000000000..64c780a341 --- /dev/null +++ b/synapse/http/federation/matrix_federation_agent.py
@@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.web.client import URI, Agent, HTTPConnectionPool +from twisted.web.iweb import IAgent + +from synapse.http.endpoint import parse_server_name +from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list +from synapse.util.logcontext import make_deferred_yieldable + +logger = logging.getLogger(__name__) + + +@implementer(IAgent) +class MatrixFederationAgent(object): + """An Agent-like thing which provides a `request` method which will look up a matrix + server and send an HTTP request to it. + + Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) + + Args: + reactor (IReactor): twisted reactor to use for underlying requests + + tls_client_options_factory (ClientTLSOptionsFactory|None): + factory to use for fetching client tls options, or none to disable TLS. + + srv_resolver (SrvResolver|None): + SRVResolver impl to use for looking up SRV records. None to use a default + implementation. + """ + + def __init__( + self, reactor, tls_client_options_factory, _srv_resolver=None, + ): + self._reactor = reactor + self._tls_client_options_factory = tls_client_options_factory + if _srv_resolver is None: + _srv_resolver = SrvResolver() + self._srv_resolver = _srv_resolver + + self._pool = HTTPConnectionPool(reactor) + self._pool.retryAutomatically = False + self._pool.maxPersistentPerHost = 5 + self._pool.cachedConnectionTimeout = 2 * 60 + + @defer.inlineCallbacks + def request(self, method, uri, headers=None, bodyProducer=None): + """ + Args: + method (bytes): HTTP method: GET/POST/etc + + uri (bytes): Absolute URI to be retrieved + + headers (twisted.web.http_headers.Headers|None): + HTTP headers to send with the request, or None to + send no extra headers. + + bodyProducer (twisted.web.iweb.IBodyProducer|None): + An object which can generate bytes to make up the + body of this request (for example, the properly encoded contents of + a file for a file upload). Or None if the request is to have + no body. + + Returns: + Deferred[twisted.web.iweb.IResponse]: + fires when the header of the response has been received (regardless of the + response status code). Fails if there is any problem which prevents that + response from being received (including problems that prevent the request + from being sent). + """ + + parsed_uri = URI.fromBytes(uri) + server_name_bytes = parsed_uri.netloc + host, port = parse_server_name(server_name_bytes.decode("ascii")) + + # XXX disabling TLS is really only supported here for the benefit of the + # unit tests. We should make the UTs cope with TLS rather than having to make + # the code support the unit tests. + if self._tls_client_options_factory is None: + tls_options = None + else: + tls_options = self._tls_client_options_factory.get_options(host) + + if port is not None: + target = (host, port) + else: + server_list = yield self._srv_resolver.resolve_service(server_name_bytes) + if not server_list: + target = (host, 8448) + logger.debug("No SRV record for %s, using %s", host, target) + else: + target = pick_server_from_list(server_list) + + class EndpointFactory(object): + @staticmethod + def endpointForURI(_uri): + logger.info("Connecting to %s:%s", target[0], target[1]) + ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1]) + if tls_options is not None: + ep = wrapClientTLS(tls_options, ep) + return ep + + agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool) + res = yield make_deferred_yieldable( + agent.request(method, uri, headers, bodyProducer) + ) + defer.returnValue(res) diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py new file mode 100644
index 0000000000..71830c549d --- /dev/null +++ b/synapse/http/federation/srv_resolver.py
@@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import random +import time + +import attr + +from twisted.internet import defer +from twisted.internet.error import ConnectError +from twisted.names import client, dns +from twisted.names.error import DNSNameError, DomainError + +from synapse.util.logcontext import make_deferred_yieldable + +logger = logging.getLogger(__name__) + +SERVER_CACHE = {} + + +@attr.s +class Server(object): + """ + Our record of an individual server which can be tried to reach a destination. + + Attributes: + host (bytes): target hostname + port (int): + priority (int): + weight (int): + expires (int): when the cache should expire this record - in *seconds* since + the epoch + """ + host = attr.ib() + port = attr.ib() + priority = attr.ib(default=0) + weight = attr.ib(default=0) + expires = attr.ib(default=0) + + +def pick_server_from_list(server_list): + """Randomly choose a server from the server list + + Args: + server_list (list[Server]): list of candidate servers + + Returns: + Tuple[bytes, int]: (host, port) pair for the chosen server + """ + if not server_list: + raise RuntimeError("pick_server_from_list called with empty list") + + # TODO: currently we only use the lowest-priority servers. We should maintain a + # cache of servers known to be "down" and filter them out + + min_priority = min(s.priority for s in server_list) + eligible_servers = list(s for s in server_list if s.priority == min_priority) + total_weight = sum(s.weight for s in eligible_servers) + target_weight = random.randint(0, total_weight) + + for s in eligible_servers: + target_weight -= s.weight + + if target_weight <= 0: + return s.host, s.port + + # this should be impossible. + raise RuntimeError( + "pick_server_from_list got to end of eligible server list.", + ) + + +class SrvResolver(object): + """Interface to the dns client to do SRV lookups, with result caching. + + The default resolver in twisted.names doesn't do any caching (it has a CacheResolver, + but the cache never gets populated), so we add our own caching layer here. + + Args: + dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl + cache (dict): cache object + get_time (callable): clock implementation. Should return seconds since the epoch + """ + def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): + self._dns_client = dns_client + self._cache = cache + self._get_time = get_time + + @defer.inlineCallbacks + def resolve_service(self, service_name): + """Look up a SRV record + + Args: + service_name (bytes): record to look up + + Returns: + Deferred[list[Server]]: + a list of the SRV records, or an empty list if none found + """ + now = int(self._get_time()) + + if not isinstance(service_name, bytes): + raise TypeError("%r is not a byte string" % (service_name,)) + + cache_entry = self._cache.get(service_name, None) + if cache_entry: + if all(s.expires > now for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + + try: + answers, _, _ = yield make_deferred_yieldable( + self._dns_client.lookupService(service_name), + ) + except DNSNameError: + # TODO: cache this. We can get the SOA out of the exception, and use + # the negative-TTL value. + defer.returnValue([]) + except DomainError as e: + # We failed to resolve the name (other than a NameError) + # Try something in the cache, else rereaise + cache_entry = self._cache.get(service_name, None) + if cache_entry: + logger.warn( + "Failed to resolve %r, falling back to cache. %r", + service_name, e + ) + defer.returnValue(list(cache_entry)) + else: + raise e + + if (len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name(b'.')): + raise ConnectError("Service %s unavailable" % service_name) + + servers = [] + + for answer in answers: + if answer.type != dns.SRV or not answer.payload: + continue + + payload = answer.payload + + servers.append(Server( + host=payload.target.name, + port=payload.port, + priority=payload.priority, + weight=payload.weight, + expires=now + answer.ttl, + )) + + self._cache[service_name] = list(servers) + defer.returnValue(servers) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index be4076fc6a..980e912348 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -32,7 +32,7 @@ from twisted.internet import defer, protocol from twisted.internet.error import DNSLookupError from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool +from twisted.web.client import FileBodyProducer from twisted.web.http_headers import Headers import synapse.metrics @@ -44,7 +44,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) -from synapse.http.endpoint import matrix_federation_endpoint +from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.util.async_helpers import timeout_deferred from synapse.util.logcontext import make_deferred_yieldable from synapse.util.metrics import Measure @@ -66,20 +66,6 @@ else: MAXINT = sys.maxint -class MatrixFederationEndpointFactory(object): - def __init__(self, hs): - self.reactor = hs.get_reactor() - self.tls_client_options_factory = hs.tls_client_options_factory - - def endpointForURI(self, uri): - destination = uri.netloc.decode('ascii') - - return matrix_federation_endpoint( - self.reactor, destination, timeout=10, - tls_client_options_factory=self.tls_client_options_factory - ) - - _next_id = 1 @@ -187,12 +173,10 @@ class MatrixFederationHttpClient(object): self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname reactor = hs.get_reactor() - pool = HTTPConnectionPool(reactor) - pool.retryAutomatically = False - pool.maxPersistentPerHost = 5 - pool.cachedConnectionTimeout = 2 * 60 - self.agent = Agent.usingEndpointFactory( - reactor, MatrixFederationEndpointFactory(hs), pool=pool + + self.agent = MatrixFederationAgent( + hs.get_reactor(), + hs.tls_client_options_factory, ) self.clock = hs.get_clock() self._store = hs.get_datastore() @@ -229,19 +213,18 @@ class MatrixFederationHttpClient(object): backoff_on_404 (bool): Back off if we get a 404 Returns: - Deferred: resolves with the http response object on success. - - Fails with ``HttpResponseException``: if we get an HTTP response - code >= 300 (except 429). - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist - - Fails with ``RequestSendFailed`` if there were problems connecting to - the remote, due to e.g. DNS failures, connection timeouts etc. + Deferred[twisted.web.client.Response]: resolves with the HTTP + response object on success. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ if timeout: _sec_timeout = timeout / 1000 @@ -299,9 +282,9 @@ class MatrixFederationHttpClient(object): json = request.get_json() if json: headers_dict[b"Content-Type"] = [b"application/json"] - self.sign_request( + auth_headers = self.build_auth_headers( destination_bytes, method_bytes, url_to_sign_bytes, - headers_dict, json, + json, ) data = encode_canonical_json(json) producer = FileBodyProducer( @@ -310,40 +293,40 @@ class MatrixFederationHttpClient(object): ) else: producer = None - self.sign_request( + auth_headers = self.build_auth_headers( destination_bytes, method_bytes, url_to_sign_bytes, - headers_dict, ) + headers_dict[b"Authorization"] = auth_headers + logger.info( - "{%s} [%s] Sending request: %s %s", + "{%s} [%s] Sending request: %s %s; timeout %fs", request.txn_id, request.destination, request.method, - url_str, - ) - - # we don't want all the fancy cookie and redirect handling that - # treq.request gives: just use the raw Agent. - request_deferred = self.agent.request( - method_bytes, - url_bytes, - headers=Headers(headers_dict), - bodyProducer=producer, - ) - - request_deferred = timeout_deferred( - request_deferred, - timeout=_sec_timeout, - reactor=self.hs.get_reactor(), + url_str, _sec_timeout, ) try: with Measure(self.clock, "outbound_request"): - response = yield make_deferred_yieldable( + # we don't want all the fancy cookie and redirect handling + # that treq.request gives: just use the raw Agent. + request_deferred = self.agent.request( + method_bytes, + url_bytes, + headers=Headers(headers_dict), + bodyProducer=producer, + ) + + request_deferred = timeout_deferred( request_deferred, + timeout=_sec_timeout, + reactor=self.hs.get_reactor(), ) + + response = yield request_deferred except DNSLookupError as e: raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e) except Exception as e: + logger.info("Failed to send request: %s", e) raise_from(RequestSendFailed(e, can_retry=True), e) logger.info( @@ -441,24 +424,23 @@ class MatrixFederationHttpClient(object): defer.returnValue(response) - def sign_request(self, destination, method, url_bytes, headers_dict, - content=None, destination_is=None): + def build_auth_headers( + self, destination, method, url_bytes, content=None, destination_is=None, + ): """ - Signs a request by adding an Authorization header to headers_dict + Builds the Authorization headers for a federation request Args: destination (bytes|None): The desination home server of the request. May be None if the destination is an identity server, in which case destination_is must be non-None. method (bytes): The HTTP method of the request url_bytes (bytes): The URI path of the request - headers_dict (dict[bytes, list[bytes]]): Dictionary of request headers to - append to content (object): The body of the request destination_is (bytes): As 'destination', but if the destination is an identity server Returns: - None + list[bytes]: a list of headers to be added as "Authorization:" headers """ request = { "method": method, @@ -485,8 +467,7 @@ class MatrixFederationHttpClient(object): self.server_name, key, sig, )).encode('ascii') ) - - headers_dict[b"Authorization"] = auth_headers + return auth_headers @defer.inlineCallbacks def put_json(self, destination, path, args={}, data={}, @@ -516,17 +497,18 @@ class MatrixFederationHttpClient(object): requests) Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( @@ -570,17 +552,18 @@ class MatrixFederationHttpClient(object): try the request anyway. args (dict): query params Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( @@ -625,17 +608,18 @@ class MatrixFederationHttpClient(object): ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ logger.debug("get_json args: %s", args) @@ -676,17 +660,18 @@ class MatrixFederationHttpClient(object): ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( method="DELETE", @@ -719,18 +704,20 @@ class MatrixFederationHttpClient(object): args (dict): Optional dictionary used to create the query string. ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. - Returns: - Deferred: resolves with an (int,dict) tuple of the file length and - a dict of the response headers. - - Fails with ``HttpResponseException`` if we get an HTTP response code - >= 300 - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Returns: + Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of + the file length and a dict of the response headers. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( method="GET", diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 69c5f9fe2e..756721e304 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py
@@ -40,7 +40,11 @@ REQUIREMENTS = [ "signedjson>=1.0.0", "pynacl>=1.2.1", "service_identity>=16.0.0", - "Twisted>=17.1.0", + + # our logcontext handling relies on the ability to cancel inlineCallbacks + # (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7. + "Twisted>=18.7.0", + "treq>=15.1", # Twisted has required pyopenssl 16.0 since about Twisted 16.6. "pyopenssl>=16.0.0", @@ -52,15 +56,18 @@ REQUIREMENTS = [ "pillow>=3.1.2", "sortedcontainers>=1.4.4", "psutil>=2.0.0", - "pymacaroons-pynacl>=0.9.3", - "msgpack-python>=0.4.2", + "pymacaroons>=0.13.0", + "msgpack>=0.5.0", "phonenumbers>=8.2.0", "six>=1.10", # prometheus_client 0.4.0 changed the format of counter metrics # (cf https://github.com/matrix-org/synapse/issues/4001) "prometheus_client>=0.0.18,<0.4.0", + # we use attr.s(slots), which arrived in 16.0.0 - "attrs>=16.0.0", + # Twisted 18.7.0 requires attrs>=17.4.0 + "attrs>=17.4.0", + "netaddr>=0.7.18", ] @@ -72,6 +79,10 @@ CONDITIONAL_REQUIREMENTS = { # ConsentResource uses select_autoescape, which arrived in jinja 2.9 "resources.consent": ["Jinja2>=2.9"], + # ACME support is required to provision TLS certificates from authorities + # that use the protocol, such as Let's Encrypt. + "acme": ["txacme>=0.9.2"], + "saml2": ["pysaml2>=4.5.0"], "url_preview": ["lxml>=3.5.0"], "test": ["mock>=2.0"], diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index aec0c6b075..14025cd219 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py
@@ -309,22 +309,16 @@ class RegisterRestServlet(RestServlet): assigned_user_id=registered_user_id, ) - # Only give msisdn flows if the x_show_msisdn flag is given: - # this is a hack to work around the fact that clients were shipped - # that use fallback registration if they see any flows that they don't - # recognise, which means we break registration for these clients if we - # advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot - # Android <=0.6.9 have fallen below an acceptable threshold, this - # parameter should go away and we should always advertise msisdn flows. - show_msisdn = False - if 'x_show_msisdn' in body and body['x_show_msisdn']: - show_msisdn = True - # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one require_email = 'email' in self.hs.config.registrations_require_3pid require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + show_msisdn = True + if self.hs.config.disable_msisdn_registration: + show_msisdn = False + require_msisdn = False + flows = [] if self.hs.config.enable_registration_captcha: # only support 3PIDless registration if no 3PIDs are required diff --git a/synapse/server.py b/synapse/server.py
index 9985687b95..c8914302cf 100644 --- a/synapse/server.py +++ b/synapse/server.py
@@ -46,6 +46,7 @@ from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler from synapse.handlers import Handlers +from synapse.handlers.acme import AcmeHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGenerator from synapse.handlers.deactivate_account import DeactivateAccountHandler @@ -129,6 +130,7 @@ class HomeServer(object): 'sync_handler', 'typing_handler', 'room_list_handler', + 'acme_handler', 'auth_handler', 'device_handler', 'e2e_keys_handler', @@ -310,6 +312,9 @@ class HomeServer(object): def build_e2e_room_keys_handler(self): return E2eRoomKeysHandler(self) + def build_acme_handler(self): + return AcmeHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 1d3069b143..254fdc04c6 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
@@ -192,6 +192,41 @@ class SQLBaseStore(object): self.database_engine = hs.database_engine + # A set of tables that are not safe to use native upserts in. + self._unsafe_to_upsert_tables = {"user_ips"} + + if self.database_engine.can_native_upsert: + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates of tables that aren't safe to update. + self._clock.call_later(0.0, self._check_safe_to_upsert) + + @defer.inlineCallbacks + def _check_safe_to_upsert(self): + """ + Is it safe to use native UPSERT? + + If there are background updates, we will need to wait, as they may be + the addition of indexes that set the UNIQUE constraint that we require. + + If the background updates have not completed, wait a second and check again. + """ + updates = yield self._simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ) + updates = [x["update_name"] for x in updates] + + # The User IPs table in schema #53 was missing a unique index, which we + # run as a background update. + if "user_ips_device_unique_index" not in updates: + self._unsafe_to_upsert_tables.discard("user_id") + + # If there's any tables left to check, reschedule to run. + if self._unsafe_to_upsert_tables: + self._clock.call_later(1.0, self._check_safe_to_upsert) + def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() @@ -494,8 +529,15 @@ class SQLBaseStore(object): txn.executemany(sql, vals) @defer.inlineCallbacks - def _simple_upsert(self, table, keyvalues, values, - insertion_values={}, desc="_simple_upsert", lock=True): + def _simple_upsert( + self, + table, + keyvalues, + values, + insertion_values={}, + desc="_simple_upsert", + lock=True + ): """ `lock` should generally be set to True (the default), but can be set @@ -516,16 +558,21 @@ class SQLBaseStore(object): inserting lock (bool): True to lock the table when doing the upsert. Returns: - Deferred(bool): True if a new entry was created, False if an - existing one was updated. + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. """ attempts = 0 while True: try: result = yield self.runInteraction( desc, - self._simple_upsert_txn, table, keyvalues, values, insertion_values, - lock=lock + self._simple_upsert_txn, + table, + keyvalues, + values, + insertion_values, + lock=lock, ) defer.returnValue(result) except self.database_engine.module.IntegrityError as e: @@ -537,21 +584,76 @@ class SQLBaseStore(object): # presumably we raced with another transaction: let's retry. logger.warn( - "IntegrityError when upserting into %s; retrying: %s", - table, e + "%s when upserting into %s; retrying: %s", e.__name__, table, e ) - def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}, - lock=True): + def _simple_upsert_txn( + self, + txn, + table, + keyvalues, + values, + insertion_values={}, + lock=True, + ): + """ + Pick the UPSERT method which works best on the platform. Either the + native one (Pg9.5+, recent SQLites), or fall back to an emulated method. + + Args: + txn: The transaction to use. + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + if ( + self.database_engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ): + return self._simple_upsert_txn_native_upsert( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + ) + else: + return self._simple_upsert_txn_emulated( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + lock=lock, + ) + + def _simple_upsert_txn_emulated( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): # We need to lock the table :(, unless we're *really* careful if lock: self.database_engine.lock_table(txn, table) + def _getwhere(key): + # If the value we're passing in is None (aka NULL), we need to use + # IS, not =, as NULL = NULL equals NULL (False). + if keyvalues[key] is None: + return "%s IS ?" % (key,) + else: + return "%s = ?" % (key,) + # First try to update. sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), - " AND ".join("%s = ?" % (k,) for k in keyvalues) + " AND ".join(_getwhere(k) for k in keyvalues) ) sqlargs = list(values.values()) + list(keyvalues.values()) @@ -569,12 +671,44 @@ class SQLBaseStore(object): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( table, ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues) + ", ".join("?" for _ in allvalues), ) txn.execute(sql, list(allvalues.values())) # successfully inserted return True + def _simple_upsert_txn_native_upsert( + self, txn, table, keyvalues, values, insertion_values={} + ): + """ + Use the native UPSERT functionality in recent PostgreSQL versions. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + Returns: + None + """ + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) + + sql = ( + "INSERT INTO %s (%s) VALUES (%s) " + "ON CONFLICT (%s) DO UPDATE SET %s" + ) % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ", ".join(k for k in keyvalues), + ", ".join(k + "=EXCLUDED." + k for k in values), + ) + txn.execute(sql, list(allvalues.values())) + def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 11513af2ec..4a221d90ec 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py
@@ -65,7 +65,27 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): columns=["last_seen"], ) - # (user_id, access_token, ip) -> (user_agent, device_id, last_seen) + self.register_background_update_handler( + "user_ips_remove_dupes", + self._remove_user_ip_dupes, + ) + + # Register a unique index + self.register_background_index_update( + "user_ips_device_unique_index", + index_name="user_ips_user_token_ip_unique_index", + table="user_ips", + columns=["user_id", "access_token", "ip"], + unique=True, + ) + + # Drop the old non-unique index + self.register_background_update_handler( + "user_ips_drop_nonunique_index", + self._remove_user_ip_nonunique, + ) + + # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) self._batch_row_update = {} self._client_ip_looper = self._clock.looping_call( @@ -76,6 +96,129 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): ) @defer.inlineCallbacks + def _remove_user_ip_nonunique(self, progress, batch_size): + def f(conn): + txn = conn.cursor() + txn.execute( + "DROP INDEX IF EXISTS user_ips_user_ip" + ) + txn.close() + + yield self.runWithConnection(f) + yield self._end_background_update("user_ips_drop_nonunique_index") + defer.returnValue(1) + + @defer.inlineCallbacks + def _remove_user_ip_dupes(self, progress, batch_size): + # This works function works by scanning the user_ips table in batches + # based on `last_seen`. For each row in a batch it searches the rest of + # the table to see if there are any duplicates, if there are then they + # are removed and replaced with a suitable row. + + # Fetch the start of the batch + begin_last_seen = progress.get("last_seen", 0) + + def get_last_seen(txn): + txn.execute( + """ + SELECT last_seen FROM user_ips + WHERE last_seen > ? + ORDER BY last_seen + LIMIT 1 + OFFSET ? + """, + (begin_last_seen, batch_size) + ) + row = txn.fetchone() + if row: + return row[0] + else: + return None + + # Get a last seen that has roughly `batch_size` since `begin_last_seen` + end_last_seen = yield self.runInteraction( + "user_ips_dups_get_last_seen", get_last_seen + ) + + # If it returns None, then we're processing the last batch + last = end_last_seen is None + + logger.info( + "Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s", + begin_last_seen, end_last_seen, + ) + + def remove(txn): + # This works by looking at all entries in the given time span, and + # then for each (user_id, access_token, ip) tuple in that range + # checking for any duplicates in the rest of the table (via a join). + # It then only returns entries which have duplicates, and the max + # last_seen across all duplicates, which can the be used to delete + # all other duplicates. + # It is efficient due to the existence of (user_id, access_token, + # ip) and (last_seen) indices. + + # Define the search space, which requires handling the last batch in + # a different way + if last: + clause = "? <= last_seen" + args = (begin_last_seen,) + else: + clause = "? <= last_seen AND last_seen < ?" + args = (begin_last_seen, end_last_seen) + + txn.execute( + """ + SELECT user_id, access_token, ip, + MAX(device_id), MAX(user_agent), MAX(last_seen) + FROM ( + SELECT user_id, access_token, ip + FROM user_ips + WHERE {} + ) c + INNER JOIN user_ips USING (user_id, access_token, ip) + GROUP BY user_id, access_token, ip + HAVING count(*) > 1 + """.format(clause), + args + ) + res = txn.fetchall() + + # We've got some duplicates + for i in res: + user_id, access_token, ip, device_id, user_agent, last_seen = i + + # Drop all the duplicates + txn.execute( + """ + DELETE FROM user_ips + WHERE user_id = ? AND access_token = ? AND ip = ? + """, + (user_id, access_token, ip) + ) + + # Add in one to be the last_seen + txn.execute( + """ + INSERT INTO user_ips + (user_id, access_token, ip, device_id, user_agent, last_seen) + VALUES (?, ?, ?, ?, ?, ?) + """, + (user_id, access_token, ip, device_id, user_agent, last_seen) + ) + + self._background_update_progress_txn( + txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} + ) + + yield self.runInteraction("user_ips_dups_remove", remove) + + if last: + yield self._end_background_update("user_ips_remove_dupes") + + defer.returnValue(batch_size) + + @defer.inlineCallbacks def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, now=None): if not now: @@ -114,7 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): ) def _update_client_ips_batch_txn(self, txn, to_update): - self.database_engine.lock_table(txn, "user_ips") + if "user_ips" in self._unsafe_to_upsert_tables or ( + not self.database_engine.can_native_upsert + ): + self.database_engine.lock_table(txn, "user_ips") for entry in iteritems(to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry @@ -127,10 +273,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "user_id": user_id, "access_token": access_token, "ip": ip, - "user_agent": user_agent, - "device_id": device_id, }, values={ + "user_agent": user_agent, + "device_id": device_id, "last_seen": last_seen, }, lock=False, @@ -227,7 +373,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): results = {} for key in self._batch_row_update: - uid, access_token, ip = key + uid, access_token, ip, = key if uid == user_id: user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index e2f9de8451..ff5ef97ca8 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py
@@ -18,7 +18,7 @@ import platform from ._base import IncorrectDatabaseSetup from .postgres import PostgresEngine -from .sqlite3 import Sqlite3Engine +from .sqlite import Sqlite3Engine SUPPORTED_MODULE = { "sqlite3": Sqlite3Engine, diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 42225f8a2a..4004427c7b 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py
@@ -38,6 +38,13 @@ class PostgresEngine(object): return sql.replace("?", "%s") def on_new_connection(self, db_conn): + + # Get the version of PostgreSQL that we're using. As per the psycopg2 + # docs: The number is formed by converting the major, minor, and + # revision numbers into two-decimal-digit numbers and appending them + # together. For example, version 8.1.5 will be returned as 80105 + self._version = db_conn.server_version + db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) @@ -54,6 +61,13 @@ class PostgresEngine(object): cursor.close() + @property + def can_native_upsert(self): + """ + Can we use native UPSERTs? This requires PostgreSQL 9.5+. + """ + return self._version >= 90500 + def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite.py
index 19949fc474..c64d73ff21 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite.py
@@ -15,6 +15,7 @@ import struct import threading +from sqlite3 import sqlite_version_info from synapse.storage.prepare_database import prepare_database @@ -30,6 +31,14 @@ class Sqlite3Engine(object): self._current_state_group_id = None self._current_state_group_id_lock = threading.Lock() + @property + def can_native_upsert(self): + """ + Do we support native UPSERTs? This requires SQLite3 3.24+, plus some + more work we haven't done yet to tell what was inserted vs updated. + """ + return sqlite_version_info >= (3, 24, 0) + def check_database(self, txn): pass diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 2047110b1d..79e0276de6 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py
@@ -739,7 +739,18 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore } events_map = {ev.event_id: ev for ev, _ in events_context} - room_version = yield self.get_room_version(room_id) + + # We need to get the room version, which is in the create event. + # Normally that'd be in the database, but its also possible that we're + # currently trying to persist it. + room_version = None + for ev, _ in events_context: + if ev.type == EventTypes.Create and ev.state_key == "": + room_version = ev.content.get("room_version", "1") + break + + if not room_version: + room_version = yield self.get_room_version(room_id) logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 2743b52bad..134297e284 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py
@@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so _simple_upsert will retry - newly_inserted = yield self._simple_upsert( + yield self._simple_upsert( table="pushers", keyvalues={ "app_id": app_id, @@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore): lock=False, ) - if newly_inserted: + user_has_pusher = self.get_if_user_has_pusher.cache.get( + (user_id,), None, update_metrics=False + ) + + if user_has_pusher is not True: + # invalidate, since we the user might not have had a pusher before yield self.runInteraction( "add_pusher", self._invalidate_cache_and_stream, diff --git a/synapse/storage/schema/delta/53/user_ips_index.sql b/synapse/storage/schema/delta/53/user_ips_index.sql new file mode 100644
index 0000000000..4ca346c111 --- /dev/null +++ b/synapse/storage/schema/delta/53/user_ips_index.sql
@@ -0,0 +1,26 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- delete duplicates +INSERT INTO background_updates (update_name, progress_json) VALUES + ('user_ips_remove_dupes', '{}'); + +-- add a new unique index to user_ips table +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('user_ips_device_unique_index', '{}', 'user_ips_remove_dupes'); + +-- drop the old original index +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index'); \ No newline at end of file diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index a8781b0e5d..ce48212265 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py
@@ -168,14 +168,14 @@ class UserDirectoryStore(SQLBaseStore): if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name - if new_entry: + if self.database_engine.can_native_upsert: sql = """ INSERT INTO user_directory_search(user_id, vector) VALUES (?, setweight(to_tsvector('english', ?), 'A') || setweight(to_tsvector('english', ?), 'D') || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - ) + ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector """ txn.execute( sql, @@ -185,20 +185,45 @@ class UserDirectoryStore(SQLBaseStore): ) ) else: - sql = """ - UPDATE user_directory_search - SET vector = setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - WHERE user_id = ? - """ - txn.execute( - sql, - ( - get_localpart_from_id(user_id), get_domain_from_id(user_id), - display_name, user_id, + # TODO: Remove this code after we've bumped the minimum version + # of postgres to always support upserts, so we can get rid of + # `new_entry` usage + if new_entry is True: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + txn.execute( + sql, + ( + user_id, get_localpart_from_id(user_id), + get_domain_from_id(user_id), display_name, + ) + ) + elif new_entry is False: + sql = """ + UPDATE user_directory_search + SET vector = setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + WHERE user_id = ? + """ + txn.execute( + sql, + ( + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, user_id, + ) + ) + else: + raise RuntimeError( + "upsert returned None when 'can_native_upsert' is False" ) - ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name,) if display_name else user_id self._simple_upsert_txn( diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index ec7b2c9672..430bb15f51 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -387,12 +387,14 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): deferred that wraps and times out the given deferred, correctly handling the case where the given deferred's canceller throws. + (See https://twistedmatrix.com/trac/ticket/9534) + NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred Args: deferred (Deferred) timeout (float): Timeout in seconds - reactor (twisted.internet.reactor): The twisted reactor to use + reactor (twisted.interfaces.IReactorTime): The twisted reactor to use on_timeout_cancel (callable): A callable which is called immediately after the deferred times out, and not if this deferred is otherwise cancelled before the timeout.