diff --git a/.coveragerc b/.coveragerc
index 9873a30738..e9460a340a 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,11 +1,7 @@
[run]
branch = True
parallel = True
-source = synapse
-
-[paths]
-source=
- coverage
+include = synapse/*
[report]
precision = 2
diff --git a/.gitignore b/.gitignore
index d739595c3a..1033124f1d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,9 +25,9 @@ homeserver*.pid
*.tls.dh
*.tls.key
-.coverage
-.coverage.*
-!.coverage.rc
+.coverage*
+coverage.*
+!.coveragerc
htmlcov
demo/*/*.db
diff --git a/changelog.d/4306.misc b/changelog.d/4306.misc
new file mode 100644
index 0000000000..58130b6190
--- /dev/null
+++ b/changelog.d/4306.misc
@@ -0,0 +1 @@
+Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.
diff --git a/changelog.d/4384.feature b/changelog.d/4384.feature
new file mode 100644
index 0000000000..daedcd58c4
--- /dev/null
+++ b/changelog.d/4384.feature
@@ -0,0 +1 @@
+Synapse can now automatically provision TLS certificates via ACME (the protocol used by CAs like Let's Encrypt).
diff --git a/changelog.d/4428.misc b/changelog.d/4428.misc
new file mode 100644
index 0000000000..9a51434755
--- /dev/null
+++ b/changelog.d/4428.misc
@@ -0,0 +1 @@
+Move SRV logic into the Agent layer.
diff --git a/changelog.d/4432.misc b/changelog.d/4432.misc
new file mode 100644
index 0000000000..047061ed3c
--- /dev/null
+++ b/changelog.d/4432.misc
@@ -0,0 +1 @@
+Apply a unique index to the user_ips table, preventing duplicates.
diff --git a/changelog.d/4433.misc b/changelog.d/4433.misc
new file mode 100644
index 0000000000..30f2912db2
--- /dev/null
+++ b/changelog.d/4433.misc
@@ -0,0 +1 @@
+debian package: symlink to explicit python version.
diff --git a/changelog.d/4434.misc b/changelog.d/4434.misc
new file mode 100644
index 0000000000..047061ed3c
--- /dev/null
+++ b/changelog.d/4434.misc
@@ -0,0 +1 @@
+Apply a unique index to the user_ips table, preventing duplicates.
diff --git a/changelog.d/4445.feature b/changelog.d/4445.feature
new file mode 100644
index 0000000000..a6f9b7bbac
--- /dev/null
+++ b/changelog.d/4445.feature
@@ -0,0 +1 @@
+Add a metric for tracking event stream position of the user directory.
\ No newline at end of file
diff --git a/changelog.d/4452.bugfix b/changelog.d/4452.bugfix
new file mode 100644
index 0000000000..a715ca3788
--- /dev/null
+++ b/changelog.d/4452.bugfix
@@ -0,0 +1 @@
+Don't send IP addresses as SNI.
diff --git a/changelog.d/4458.misc b/changelog.d/4458.misc
new file mode 100644
index 0000000000..8b3bc94a34
--- /dev/null
+++ b/changelog.d/4458.misc
@@ -0,0 +1 @@
+Clarify documentation for the `public_baseurl` config param.
diff --git a/changelog.d/4459.misc b/changelog.d/4459.misc
new file mode 100644
index 0000000000..58130b6190
--- /dev/null
+++ b/changelog.d/4459.misc
@@ -0,0 +1 @@
+Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.
diff --git a/debian/build_virtualenv b/debian/build_virtualenv
index 83346c40f1..8b51b9e074 100755
--- a/debian/build_virtualenv
+++ b/debian/build_virtualenv
@@ -6,7 +6,16 @@
set -e
export DH_VIRTUALENV_INSTALL_ROOT=/opt/venvs
-SNAKE=/usr/bin/python3
+
+# make sure that the virtualenv links to the specific version of python, by
+# dereferencing the python3 symlink.
+#
+# Otherwise, if somebody tries to install (say) the stretch package on buster,
+# they will get a confusing error about "No module named 'synapse'", because
+# python won't look in the right directory. At least this way, the error will
+# be a *bit* more obvious.
+#
+SNAKE=`readlink -e /usr/bin/python3`
# try to set the CFLAGS so any compiled C extensions are compiled with the most
# generic as possible x64 instructions, so that compiling it on a new Intel chip
@@ -46,3 +55,7 @@ cp -r tests "$tmpdir"
PYTHONPATH="$tmpdir" \
debian/matrix-synapse-py3/opt/venvs/matrix-synapse/bin/python \
-B -m twisted.trial --reporter=text -j2 tests
+
+# add a dependency on the right version of python to substvars.
+PYPKG=`basename $SNAKE`
+echo "synapse:pydepends=$PYPKG" >> debian/matrix-synapse-py3.substvars
diff --git a/debian/control b/debian/control
index b85e36c6ca..4abfa02051 100644
--- a/debian/control
+++ b/debian/control
@@ -27,8 +27,8 @@ Depends:
adduser,
debconf,
python3-distutils|libpython3-stdlib (<< 3.6),
- python3,
${misc:Depends},
+ ${synapse:pydepends},
# some of our scripts use perl, but none of them are important,
# so we put perl:Depends in Suggests rather than Depends.
Suggests:
diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages
index 577d93e6f6..6b9be99060 100755
--- a/scripts-dev/build_debian_packages
+++ b/scripts-dev/build_debian_packages
@@ -10,12 +10,12 @@
# can be passed on the commandline for debugging.
import argparse
-from concurrent.futures import ThreadPoolExecutor
import os
import signal
import subprocess
import sys
import threading
+from concurrent.futures import ThreadPoolExecutor
DISTS = (
"debian:stretch",
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index f3ac3d19f0..ffc49d77cc 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -13,10 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import gc
import logging
import os
import sys
+import traceback
from six import iteritems
@@ -324,17 +326,12 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
@@ -361,12 +358,53 @@ def setup(config_options):
logger.info("Database prepared in %s.", config.database_config['name'])
hs.setup()
- hs.start_listening()
+ @defer.inlineCallbacks
def start():
- hs.get_pusherpool().start()
- hs.get_datastore().start_profiling()
- hs.get_datastore().start_doing_background_updates()
+ try:
+ # Check if the certificate is still valid.
+ cert_days_remaining = hs.config.is_disk_cert_valid()
+
+ if hs.config.acme_enabled:
+ # If ACME is enabled, we might need to provision a certificate
+ # before starting.
+ acme = hs.get_acme_handler()
+
+                # Start up the web server that we will use to respond to ACME
+                # challenges.
+ yield acme.start_listening()
+
+                # We want to reprovision if cert_days_remaining is None (meaning no
+                # certificate exists), or if the number of days remaining is not
+                # greater than our reprovisioning threshold.
+ if (cert_days_remaining is None) or (
+ not cert_days_remaining > hs.config.acme_reprovision_threshold
+ ):
+ yield acme.provision_certificate()
+
+ # Read the certificate from disk and build the context factories for
+ # TLS.
+ hs.config.read_certificate_from_disk()
+ hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
+ hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
+ config
+ )
+
+ # It is now safe to start your Synapse.
+ hs.start_listening()
+ hs.get_pusherpool().start()
+ hs.get_datastore().start_profiling()
+ hs.get_datastore().start_doing_background_updates()
+ except Exception as e:
+ # If a DeferredList failed (like in listening on the ACME listener),
+ # we need to print the subfailure explicitly.
+ if isinstance(e, defer.FirstError):
+ e.subFailure.printTraceback(sys.stderr)
+ sys.exit(1)
+
+ # Something else went wrong when starting. Print it and bail out.
+ traceback.print_exc(file=sys.stderr)
+ sys.exit(1)
reactor.callWhenRunning(start)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index fd2d6d52ef..5858fb92b4 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -367,7 +367,7 @@ class Config(object):
if not keys_directory:
keys_directory = os.path.dirname(config_files[-1])
- config_dir_path = os.path.abspath(keys_directory)
+ self.config_dir_path = os.path.abspath(keys_directory)
specified_config = {}
for config_file in config_files:
@@ -379,7 +379,7 @@ class Config(object):
server_name = specified_config["server_name"]
config_string = self.generate_config(
- config_dir_path=config_dir_path,
+ config_dir_path=self.config_dir_path,
data_dir_path=os.getcwd(),
server_name=server_name,
generate_secrets=False,
diff --git a/synapse/config/server.py b/synapse/config/server.py
index a915bb8b64..22dcc87d8a 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -256,7 +256,11 @@ class ServerConfig(Config):
#
# web_client_location: "/path/to/web/root"
- # The public-facing base URL for the client API (not including _matrix/...)
+ # The public-facing base URL that clients use to access this HS
+ # (not including _matrix/...). This is the same URL a user would
+ # enter into the 'custom HS URL' field on their client. If you
+ # use synapse with a reverse proxy, this should be the URL to reach
+ # synapse via the proxy.
# public_baseurl: https://example.com:8448/
# Set the soft limit on the number of file descriptors synapse can use
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index bb8952c672..a75e233aa0 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -13,60 +13,110 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import os
+from datetime import datetime
from hashlib import sha256
from unpaddedbase64 import encode_base64
from OpenSSL import crypto
-from ._base import Config
+from synapse.config._base import Config
+
+logger = logging.getLogger(__name__)
class TlsConfig(Config):
def read_config(self, config):
- self.tls_certificate = self.read_tls_certificate(
- config.get("tls_certificate_path")
- )
- self.tls_certificate_file = config.get("tls_certificate_path")
+ acme_config = config.get("acme", {})
+ self.acme_enabled = acme_config.get("enabled", False)
+ self.acme_url = acme_config.get(
+ "url", "https://acme-v01.api.letsencrypt.org/directory"
+ )
+ self.acme_port = acme_config.get("port", 8449)
+ self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"])
+ self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
+
+ self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path"))
+ self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path"))
+ self._original_tls_fingerprints = config["tls_fingerprints"]
+ self.tls_fingerprints = list(self._original_tls_fingerprints)
self.no_tls = config.get("no_tls", False)
- if self.no_tls:
- self.tls_private_key = None
- else:
- self.tls_private_key = self.read_tls_private_key(
- config.get("tls_private_key_path")
- )
+ # This config option applies to non-federation HTTP clients
+ # (e.g. for talking to recaptcha, identity servers, and such)
+ # It should never be used in production, and is intended for
+ # use only when running tests.
+ self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
+ "use_insecure_ssl_client_just_for_testing_do_not_use"
+ )
+
+ self.tls_certificate = None
+ self.tls_private_key = None
+
+ def is_disk_cert_valid(self):
+ """
+        Work out how many days of validity the certificate on disk has left.
+
+        Returns:
+            int: Days remaining until expiry (negative once expired).
+            None: No certificate exists.
+ """
+ if not os.path.exists(self.tls_certificate_file):
+ return None
+
+ try:
+ with open(self.tls_certificate_file, 'rb') as f:
+ cert_pem = f.read()
+ except Exception:
+ logger.exception("Failed to read existing certificate off disk!")
+ raise
+
+ try:
+ tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
+ except Exception:
+ logger.exception("Failed to parse existing certificate off disk!")
+ raise
+
+ # YYYYMMDDhhmmssZ -- in UTC
+ expires_on = datetime.strptime(
+ tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
+ )
+ now = datetime.utcnow()
+ days_remaining = (expires_on - now).days
+ return days_remaining
- self.tls_fingerprints = config["tls_fingerprints"]
+ def read_certificate_from_disk(self):
+ """
+ Read the certificates from disk.
+ """
+ self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file)
+
+ if not self.no_tls:
+ self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file)
+
+ self.tls_fingerprints = list(self._original_tls_fingerprints)
# Check that our own certificate is included in the list of fingerprints
# and include it if it is not.
x509_certificate_bytes = crypto.dump_certificate(
- crypto.FILETYPE_ASN1,
- self.tls_certificate
+ crypto.FILETYPE_ASN1, self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
- # This config option applies to non-federation HTTP clients
- # (e.g. for talking to recaptcha, identity servers, and such)
- # It should never be used in production, and is intended for
- # use only when running tests.
- self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
- "use_insecure_ssl_client_just_for_testing_do_not_use"
- )
-
def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt"
tls_private_key_path = base_key_name + ".tls.key"
- return """\
+ return (
+ """\
# PEM encoded X509 certificate for TLS.
# You can replace the self-signed certificate that synapse
# autogenerates on launch with your own SSL certificate + key pair
@@ -107,7 +157,24 @@ class TlsConfig(Config):
#
tls_fingerprints: []
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
- """ % locals()
+
+ ## Support for ACME certificate auto-provisioning.
+ # acme:
+ # enabled: false
+ ## ACME path.
+ ## If you only want to test, use the staging url:
+ ## https://acme-staging.api.letsencrypt.org/directory
+ # url: 'https://acme-v01.api.letsencrypt.org/directory'
+ ## Port number (to listen for the HTTP-01 challenge).
+ ## Using port 80 requires utilising something like authbind, or proxying to it.
+ # port: 8449
+ ## Hosts to bind to.
+ # bind_addresses: ['127.0.0.1']
+ ## How many days remaining on a certificate before it is renewed.
+ # reprovision_threshold: 30
+ """
+ % locals()
+ )
def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate")
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 6ba3eca7b2..286ad80100 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -17,6 +17,7 @@ from zope.interface import implementer
from OpenSSL import SSL, crypto
from twisted.internet._sslverify import _defaultCurveName
+from twisted.internet.abstract import isIPAddress, isIPv6Address
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.ssl import CertificateOptions, ContextFactory
from twisted.python.failure import Failure
@@ -98,8 +99,14 @@ class ClientTLSOptions(object):
def __init__(self, hostname, ctx):
self._ctx = ctx
- self._hostname = hostname
- self._hostnameBytes = _idnaBytes(hostname)
+
+ if isIPAddress(hostname) or isIPv6Address(hostname):
+ self._hostnameBytes = hostname.encode('ascii')
+ self._sendSNI = False
+ else:
+ self._hostnameBytes = _idnaBytes(hostname)
+ self._sendSNI = True
+
ctx.set_info_callback(
_tolerateErrors(self._identityVerifyingInfoCallback)
)
@@ -111,7 +118,9 @@ class ClientTLSOptions(object):
return connection
def _identityVerifyingInfoCallback(self, connection, where, ret):
- if where & SSL.SSL_CB_HANDSHAKE_START:
+        # Literal IPv4 and IPv6 addresses are not permitted in the SNI
+        # extension (RFC 6066, section 3), so only send it for DNS names.
+ if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
connection.set_tlsext_host_name(self._hostnameBytes)
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
new file mode 100644
index 0000000000..73ea7ed018
--- /dev/null
+++ b/synapse/handlers/acme.py
@@ -0,0 +1,147 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import serverFromString
+from twisted.python.filepath import FilePath
+from twisted.python.url import URL
+from twisted.web import server, static
+from twisted.web.resource import Resource
+
+logger = logging.getLogger(__name__)
+
+try:
+ from txacme.interfaces import ICertificateStore
+
+ @attr.s
+ @implementer(ICertificateStore)
+ class ErsatzStore(object):
+ """
+ A store that only stores in memory.
+ """
+
+ certs = attr.ib(default=attr.Factory(dict))
+
+ def store(self, server_name, pem_objects):
+ self.certs[server_name] = [o.as_bytes() for o in pem_objects]
+ return defer.succeed(None)
+
+
+except ImportError:
+ # txacme is missing
+ pass
+
+
+class AcmeHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.reactor = hs.get_reactor()
+
+ @defer.inlineCallbacks
+ def start_listening(self):
+
+ # Configure logging for txacme, if you need to debug
+ # from eliot import add_destinations
+ # from eliot.twisted import TwistedDestination
+ #
+ # add_destinations(TwistedDestination())
+
+ from txacme.challenges import HTTP01Responder
+ from txacme.service import AcmeIssuingService
+ from txacme.endpoint import load_or_create_client_key
+ from txacme.client import Client
+ from josepy.jwa import RS256
+
+ self._store = ErsatzStore()
+ responder = HTTP01Responder()
+
+ self._issuer = AcmeIssuingService(
+ cert_store=self._store,
+ client_creator=(
+ lambda: Client.from_url(
+ reactor=self.reactor,
+ url=URL.from_text(self.hs.config.acme_url),
+ key=load_or_create_client_key(
+ FilePath(self.hs.config.config_dir_path)
+ ),
+ alg=RS256,
+ )
+ ),
+ clock=self.reactor,
+ responders=[responder],
+ )
+
+ well_known = Resource()
+ well_known.putChild(b'acme-challenge', responder.resource)
+ responder_resource = Resource()
+ responder_resource.putChild(b'.well-known', well_known)
+ responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
+
+ srv = server.Site(responder_resource)
+
+ listeners = []
+
+ for host in self.hs.config.acme_bind_addresses:
+ logger.info(
+ "Listening for ACME requests on %s:%s", host, self.hs.config.acme_port
+ )
+ endpoint = serverFromString(
+ self.reactor, "tcp:%s:interface=%s" % (self.hs.config.acme_port, host)
+ )
+ listeners.append(endpoint.listen(srv))
+
+ # Make sure we are registered to the ACME server. There's no public API
+ # for this, it is usually triggered by startService, but since we don't
+ # want it to control where we save the certificates, we have to reach in
+ # and trigger the registration machinery ourselves.
+ self._issuer._registered = False
+ yield self._issuer._ensure_registered()
+
+        # Wait until all the listeners have started (or any one of them has failed).
+ yield defer.DeferredList(listeners, fireOnOneErrback=True, consumeErrors=True)
+
+ @defer.inlineCallbacks
+ def provision_certificate(self):
+
+ logger.warning("Reprovisioning %s", self.hs.hostname)
+
+ try:
+ yield self._issuer.issue_cert(self.hs.hostname)
+ except Exception:
+ logger.exception("Fail!")
+ raise
+ logger.warning("Reprovisioned %s, saving.", self.hs.hostname)
+ cert_chain = self._store.certs[self.hs.hostname]
+
+ try:
+ with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
+ for x in cert_chain:
+ if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
+ private_key_file.write(x)
+
+ with open(self.hs.config.tls_certificate_file, "wb") as certificate_file:
+ for x in cert_chain:
+ if x.startswith(b"-----BEGIN CERTIFICATE-----"):
+ certificate_file.write(x)
+ except Exception:
+ logger.exception("Failed saving!")
+ raise
+
+ defer.returnValue(True)
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 3c40999338..120815b09b 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -19,6 +19,7 @@ from six import iteritems
from twisted.internet import defer
+import synapse.metrics
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
@@ -163,6 +164,11 @@ class UserDirectoryHandler(object):
yield self._handle_deltas(deltas)
self.pos = deltas[-1]["stream_id"]
+
+ # Expose current event processing position to prometheus
+ synapse.metrics.event_processing_positions.labels(
+ "user_dir").set(self.pos)
+
yield self.store.update_user_directory_stream_pos(self.pos)
@defer.inlineCallbacks
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 815f8ff2f7..cd79ebab62 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -13,15 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import random
import re
-from twisted.internet import defer
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.error import ConnectError
-
-from synapse.http.federation.srv_resolver import Server, resolve_service
-
logger = logging.getLogger(__name__)
@@ -88,140 +81,3 @@ def parse_and_validate_server_name(server_name):
))
return host, port
-
-
-def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
- timeout=None):
- """Construct an endpoint for the given matrix destination.
-
- Args:
- reactor: Twisted reactor.
- destination (unicode): The name of the server to connect to.
- tls_client_options_factory
- (synapse.crypto.context_factory.ClientTLSOptionsFactory):
- Factory which generates TLS options for client connections.
- timeout (int): connection timeout in seconds
- """
-
- domain, port = parse_server_name(destination)
-
- endpoint_kw_args = {}
-
- if timeout is not None:
- endpoint_kw_args.update(timeout=timeout)
-
- if tls_client_options_factory is None:
- transport_endpoint = HostnameEndpoint
- default_port = 8008
- else:
- # the SNI string should be the same as the Host header, minus the port.
- # as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777,
- # the Host header and SNI should therefore be the server_name of the remote
- # server.
- tls_options = tls_client_options_factory.get_options(domain)
-
- def transport_endpoint(reactor, host, port, timeout):
- return wrapClientTLS(
- tls_options,
- HostnameEndpoint(reactor, host, port, timeout=timeout),
- )
- default_port = 8448
-
- if port is None:
- return SRVClientEndpoint(
- reactor, "matrix", domain, protocol="tcp",
- default_port=default_port, endpoint=transport_endpoint,
- endpoint_kw_args=endpoint_kw_args
- )
- else:
- return transport_endpoint(
- reactor, domain, port, **endpoint_kw_args
- )
-
-
-class SRVClientEndpoint(object):
- """An endpoint which looks up SRV records for a service.
- Cycles through the list of servers starting with each call to connect
- picking the next server.
- Implements twisted.internet.interfaces.IStreamClientEndpoint.
- """
-
- def __init__(self, reactor, service, domain, protocol="tcp",
- default_port=None, endpoint=HostnameEndpoint,
- endpoint_kw_args={}):
- self.reactor = reactor
- self.service_name = "_%s._%s.%s" % (service, protocol, domain)
-
- if default_port is not None:
- self.default_server = Server(
- host=domain,
- port=default_port,
- )
- else:
- self.default_server = None
-
- self.endpoint = endpoint
- self.endpoint_kw_args = endpoint_kw_args
-
- self.servers = None
- self.used_servers = None
-
- @defer.inlineCallbacks
- def fetch_servers(self):
- self.used_servers = []
- self.servers = yield resolve_service(self.service_name)
-
- def pick_server(self):
- if not self.servers:
- if self.used_servers:
- self.servers = self.used_servers
- self.used_servers = []
- self.servers.sort()
- elif self.default_server:
- return self.default_server
- else:
- raise ConnectError(
- "No server available for %s" % self.service_name
- )
-
- # look for all servers with the same priority
- min_priority = self.servers[0].priority
- weight_indexes = list(
- (index, server.weight + 1)
- for index, server in enumerate(self.servers)
- if server.priority == min_priority
- )
-
- total_weight = sum(weight for index, weight in weight_indexes)
- target_weight = random.randint(0, total_weight)
- for index, weight in weight_indexes:
- target_weight -= weight
- if target_weight <= 0:
- server = self.servers[index]
- # XXX: this looks totally dubious:
- #
- # (a) we never reuse a server until we have been through
- # all of the servers at the same priority, so if the
- # weights are A: 100, B:1, we always do ABABAB instead of
- # AAAA...AAAB (approximately).
- #
- # (b) After using all the servers at the lowest priority,
- # we move onto the next priority. We should only use the
- # second priority if servers at the top priority are
- # unreachable.
- #
- del self.servers[index]
- self.used_servers.append(server)
- return server
-
- @defer.inlineCallbacks
- def connect(self, protocolFactory):
- if self.servers is None:
- yield self.fetch_servers()
- server = self.pick_server()
- logger.info("Connecting to %s:%s", server.host, server.port)
- endpoint = self.endpoint(
- self.reactor, server.host, server.port, **self.endpoint_kw_args
- )
- connection = yield endpoint.connect(protocolFactory)
- defer.returnValue(connection)
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
new file mode 100644
index 0000000000..64c780a341
--- /dev/null
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.web.client import URI, Agent, HTTPConnectionPool
+from twisted.web.iweb import IAgent
+
+from synapse.http.endpoint import parse_server_name
+from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.util.logcontext import make_deferred_yieldable
+
+logger = logging.getLogger(__name__)
+
+
+@implementer(IAgent)
+class MatrixFederationAgent(object):
+ """An Agent-like thing which provides a `request` method which will look up a matrix
+ server and send an HTTP request to it.
+
+ Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
+
+ Args:
+ reactor (IReactor): twisted reactor to use for underlying requests
+
+        tls_client_options_factory (ClientTLSOptionsFactory|None):
+            factory to use for fetching client tls options, or None to disable TLS.
+
+        _srv_resolver (SrvResolver|None):
+            SrvResolver impl to use for looking up SRV records. None to use a default
+            implementation.
+ """
+
+ def __init__(
+ self, reactor, tls_client_options_factory, _srv_resolver=None,
+ ):
+ self._reactor = reactor
+ self._tls_client_options_factory = tls_client_options_factory
+ if _srv_resolver is None:
+ _srv_resolver = SrvResolver()
+ self._srv_resolver = _srv_resolver
+
+ self._pool = HTTPConnectionPool(reactor)
+ self._pool.retryAutomatically = False
+ self._pool.maxPersistentPerHost = 5
+ self._pool.cachedConnectionTimeout = 2 * 60
+
+ @defer.inlineCallbacks
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ """
+ Args:
+ method (bytes): HTTP method: GET/POST/etc
+
+ uri (bytes): Absolute URI to be retrieved
+
+ headers (twisted.web.http_headers.Headers|None):
+ HTTP headers to send with the request, or None to
+ send no extra headers.
+
+ bodyProducer (twisted.web.iweb.IBodyProducer|None):
+ An object which can generate bytes to make up the
+ body of this request (for example, the properly encoded contents of
+ a file for a file upload). Or None if the request is to have
+ no body.
+
+ Returns:
+ Deferred[twisted.web.iweb.IResponse]:
+ fires when the header of the response has been received (regardless of the
+ response status code). Fails if there is any problem which prevents that
+ response from being received (including problems that prevent the request
+ from being sent).
+ """
+
+ parsed_uri = URI.fromBytes(uri)
+ server_name_bytes = parsed_uri.netloc
+ host, port = parse_server_name(server_name_bytes.decode("ascii"))
+
+ # XXX disabling TLS is really only supported here for the benefit of the
+ # unit tests. We should make the UTs cope with TLS rather than having to make
+ # the code support the unit tests.
+ if self._tls_client_options_factory is None:
+ tls_options = None
+ else:
+ tls_options = self._tls_client_options_factory.get_options(host)
+
+ if port is not None:
+ target = (host, port)
+ else:
+ server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
+ if not server_list:
+ target = (host, 8448)
+ logger.debug("No SRV record for %s, using %s", host, target)
+ else:
+ target = pick_server_from_list(server_list)
+
+ class EndpointFactory(object):
+ @staticmethod
+ def endpointForURI(_uri):
+ logger.info("Connecting to %s:%s", target[0], target[1])
+ ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
+ if tls_options is not None:
+ ep = wrapClientTLS(tls_options, ep)
+ return ep
+
+ agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
+ res = yield make_deferred_yieldable(
+ agent.request(method, uri, headers, bodyProducer)
+ )
+ defer.returnValue(res)
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index c49b82c394..71830c549d 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+import random
import time
import attr
@@ -51,74 +52,118 @@ class Server(object):
expires = attr.ib(default=0)
-@defer.inlineCallbacks
-def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
- """Look up a SRV record, with caching
+def pick_server_from_list(server_list):
+ """Randomly choose a server from the server list
+
+ Args:
+ server_list (list[Server]): list of candidate servers
+
+ Returns:
+ Tuple[bytes, int]: (host, port) pair for the chosen server
+ """
+ if not server_list:
+ raise RuntimeError("pick_server_from_list called with empty list")
+
+ # TODO: currently we only use the lowest-priority servers. We should maintain a
+ # cache of servers known to be "down" and filter them out
+
+ min_priority = min(s.priority for s in server_list)
+ eligible_servers = list(s for s in server_list if s.priority == min_priority)
+ total_weight = sum(s.weight for s in eligible_servers)
+ target_weight = random.randint(0, total_weight)
+
+ for s in eligible_servers:
+ target_weight -= s.weight
+
+ if target_weight <= 0:
+ return s.host, s.port
+
+ # this should be impossible.
+ raise RuntimeError(
+ "pick_server_from_list got to end of eligible server list.",
+ )
+
+
+class SrvResolver(object):
+ """Interface to the dns client to do SRV lookups, with result caching.
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
but the cache never gets populated), so we add our own caching layer here.
Args:
- service_name (unicode|bytes): record to look up
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object
- clock (object): clock implementation. must provide a time() method.
-
- Returns:
- Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
+ get_time (callable): clock implementation. Should return seconds since the epoch
"""
- # TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded
- # byteses; however they will obviously end up as separate entries in the cache. We
- # should pick one form and stick with it.
- cache_entry = cache.get(service_name, None)
- if cache_entry:
- if all(s.expires > int(clock.time()) for s in cache_entry):
- servers = list(cache_entry)
- defer.returnValue(servers)
-
- try:
- answers, _, _ = yield make_deferred_yieldable(
- dns_client.lookupService(service_name),
- )
- except DNSNameError:
- # TODO: cache this. We can get the SOA out of the exception, and use
- # the negative-TTL value.
- defer.returnValue([])
- except DomainError as e:
- # We failed to resolve the name (other than a NameError)
- # Try something in the cache, else rereaise
- cache_entry = cache.get(service_name, None)
+ def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
+ self._dns_client = dns_client
+ self._cache = cache
+ self._get_time = get_time
+
+ @defer.inlineCallbacks
+ def resolve_service(self, service_name):
+ """Look up a SRV record
+
+ Args:
+ service_name (bytes): record to look up
+
+ Returns:
+ Deferred[list[Server]]:
+ a list of the SRV records, or an empty list if none found
+ """
+ now = int(self._get_time())
+
+ if not isinstance(service_name, bytes):
+ raise TypeError("%r is not a byte string" % (service_name,))
+
+ cache_entry = self._cache.get(service_name, None)
if cache_entry:
- logger.warn(
- "Failed to resolve %r, falling back to cache. %r",
- service_name, e
+ if all(s.expires > now for s in cache_entry):
+ servers = list(cache_entry)
+ defer.returnValue(servers)
+
+ try:
+ answers, _, _ = yield make_deferred_yieldable(
+ self._dns_client.lookupService(service_name),
)
- defer.returnValue(list(cache_entry))
- else:
- raise e
-
- if (len(answers) == 1
- and answers[0].type == dns.SRV
- and answers[0].payload
- and answers[0].payload.target == dns.Name(b'.')):
- raise ConnectError("Service %s unavailable" % service_name)
-
- servers = []
-
- for answer in answers:
- if answer.type != dns.SRV or not answer.payload:
- continue
-
- payload = answer.payload
-
- servers.append(Server(
- host=payload.target.name,
- port=payload.port,
- priority=payload.priority,
- weight=payload.weight,
- expires=int(clock.time()) + answer.ttl,
- ))
-
- servers.sort() # FIXME: get rid of this (it's broken by the attrs change)
- cache[service_name] = list(servers)
- defer.returnValue(servers)
+ except DNSNameError:
+ # TODO: cache this. We can get the SOA out of the exception, and use
+ # the negative-TTL value.
+ defer.returnValue([])
+ except DomainError as e:
+ # We failed to resolve the name (other than a NameError)
+ # Try something in the cache, else re-raise
+ cache_entry = self._cache.get(service_name, None)
+ if cache_entry:
+ logger.warn(
+ "Failed to resolve %r, falling back to cache. %r",
+ service_name, e
+ )
+ defer.returnValue(list(cache_entry))
+ else:
+ raise e
+
+ if (len(answers) == 1
+ and answers[0].type == dns.SRV
+ and answers[0].payload
+ and answers[0].payload.target == dns.Name(b'.')):
+ raise ConnectError("Service %s unavailable" % service_name)
+
+ servers = []
+
+ for answer in answers:
+ if answer.type != dns.SRV or not answer.payload:
+ continue
+
+ payload = answer.payload
+
+ servers.append(Server(
+ host=payload.target.name,
+ port=payload.port,
+ priority=payload.priority,
+ weight=payload.weight,
+ expires=now + answer.ttl,
+ ))
+
+ self._cache[service_name] = list(servers)
+ defer.returnValue(servers)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 250bb1ef91..980e912348 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -32,7 +32,7 @@ from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
-from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool
+from twisted.web.client import FileBodyProducer
from twisted.web.http_headers import Headers
import synapse.metrics
@@ -44,7 +44,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
-from synapse.http.endpoint import matrix_federation_endpoint
+from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.util.async_helpers import timeout_deferred
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.metrics import Measure
@@ -66,20 +66,6 @@ else:
MAXINT = sys.maxint
-class MatrixFederationEndpointFactory(object):
- def __init__(self, hs):
- self.reactor = hs.get_reactor()
- self.tls_client_options_factory = hs.tls_client_options_factory
-
- def endpointForURI(self, uri):
- destination = uri.netloc.decode('ascii')
-
- return matrix_federation_endpoint(
- self.reactor, destination, timeout=10,
- tls_client_options_factory=self.tls_client_options_factory
- )
-
-
_next_id = 1
@@ -187,12 +173,10 @@ class MatrixFederationHttpClient(object):
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
reactor = hs.get_reactor()
- pool = HTTPConnectionPool(reactor)
- pool.retryAutomatically = False
- pool.maxPersistentPerHost = 5
- pool.cachedConnectionTimeout = 2 * 60
- self.agent = Agent.usingEndpointFactory(
- reactor, MatrixFederationEndpointFactory(hs), pool=pool
+
+ self.agent = MatrixFederationAgent(
+ hs.get_reactor(),
+ hs.tls_client_options_factory,
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
@@ -316,9 +300,9 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers
logger.info(
- "{%s} [%s] Sending request: %s %s",
+ "{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id, request.destination, request.method,
- url_str,
+ url_str, _sec_timeout,
)
try:
@@ -338,12 +322,11 @@ class MatrixFederationHttpClient(object):
reactor=self.hs.get_reactor(),
)
- response = yield make_deferred_yieldable(
- request_deferred,
- )
+ response = yield request_deferred
except DNSLookupError as e:
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
except Exception as e:
+ logger.info("Failed to send request: %s", e)
raise_from(RequestSendFailed(e, can_retry=True), e)
logger.info(
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 882e844eb1..756721e304 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -79,6 +79,10 @@ CONDITIONAL_REQUIREMENTS = {
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
"resources.consent": ["Jinja2>=2.9"],
+ # ACME support is required to provision TLS certificates from authorities
+ # that use the protocol, such as Let's Encrypt.
+ "acme": ["txacme>=0.9.2"],
+
"saml2": ["pysaml2>=4.5.0"],
"url_preview": ["lxml>=3.5.0"],
"test": ["mock>=2.0"],
diff --git a/synapse/server.py b/synapse/server.py
index 9985687b95..c8914302cf 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -46,6 +46,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler
from synapse.handlers import Handlers
+from synapse.handlers.acme import AcmeHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
from synapse.handlers.deactivate_account import DeactivateAccountHandler
@@ -129,6 +130,7 @@ class HomeServer(object):
'sync_handler',
'typing_handler',
'room_list_handler',
+ 'acme_handler',
'auth_handler',
'device_handler',
'e2e_keys_handler',
@@ -310,6 +312,9 @@ class HomeServer(object):
def build_e2e_room_keys_handler(self):
return E2eRoomKeysHandler(self)
+ def build_acme_handler(self):
+ return AcmeHandler(self)
+
def build_application_service_api(self):
return ApplicationServiceApi(self)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 865b5e915a..f62f70b9f1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -26,6 +26,7 @@ from prometheus_client import Histogram
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine
from synapse.util.caches.descriptors import Cache
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
@@ -192,6 +193,51 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
+ # A set of tables that are not safe to use native upserts in.
+ self._unsafe_to_upsert_tables = {"user_ips"}
+
+ if self.database_engine.can_native_upsert:
+ # Check ASAP (and then later, every 15s) to see if we have finished
+ # background updates of tables that aren't safe to update.
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert
+ )
+
+ @defer.inlineCallbacks
+ def _check_safe_to_upsert(self):
+ """
+ Is it safe to use native UPSERT?
+
+ If there are background updates, we will need to wait, as they may be
+ the addition of indexes that set the UNIQUE constraint that we require.
+
+ If the background updates have not completed, wait 15 sec and check again.
+ """
+ updates = yield self._simple_select_list(
+ "background_updates",
+ keyvalues=None,
+ retcols=["update_name"],
+ desc="check_background_updates",
+ )
+ updates = [x["update_name"] for x in updates]
+
+ # The User IPs table in schema #53 was missing a unique index, which we
+ # run as a background update.
+ if "user_ips_device_unique_index" not in updates:
+ self._unsafe_to_upsert_tables.discard("user_ips")
+
+ # If there's any tables left to check, reschedule to run.
+ if self._unsafe_to_upsert_tables:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert
+ )
+
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -494,8 +540,15 @@ class SQLBaseStore(object):
txn.executemany(sql, vals)
@defer.inlineCallbacks
- def _simple_upsert(self, table, keyvalues, values,
- insertion_values={}, desc="_simple_upsert", lock=True):
+ def _simple_upsert(
+ self,
+ table,
+ keyvalues,
+ values,
+ insertion_values={},
+ desc="_simple_upsert",
+ lock=True
+ ):
"""
`lock` should generally be set to True (the default), but can be set
@@ -516,16 +569,21 @@ class SQLBaseStore(object):
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
- Deferred(bool): True if a new entry was created, False if an
- existing one was updated.
+ Deferred(None or bool): Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
"""
attempts = 0
while True:
try:
result = yield self.runInteraction(
desc,
- self._simple_upsert_txn, table, keyvalues, values, insertion_values,
- lock=lock
+ self._simple_upsert_txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values,
+ lock=lock,
)
defer.returnValue(result)
except self.database_engine.module.IntegrityError as e:
@@ -537,12 +595,71 @@ class SQLBaseStore(object):
# presumably we raced with another transaction: let's retry.
logger.warn(
- "IntegrityError when upserting into %s; retrying: %s",
- table, e
+ # `e` is an exception instance, which has no `__name__` attribute;
+ # the class name has to be read from its type.
+ "%s when upserting into %s; retrying: %s", type(e).__name__, table, e
)
- def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
- lock=True):
+ def _simple_upsert_txn(
+ self,
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values={},
+ lock=True,
+ ):
+ """
+ Pick the UPSERT method which works best on the platform. Either the
+ native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+
+ Args:
+ txn: The transaction to use.
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ None or bool: Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ if (
+ self.database_engine.can_native_upsert
+ and table not in self._unsafe_to_upsert_tables
+ ):
+ return self._simple_upsert_txn_native_upsert(
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ )
+ else:
+ return self._simple_upsert_txn_emulated(
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ lock=lock,
+ )
+
+ def _simple_upsert_txn_emulated(
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
+ ):
+ """
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ bool: Return True if a new entry was created, False if an existing
+ one was updated.
+ """
# We need to lock the table :(, unless we're *really* careful
if lock:
self.database_engine.lock_table(txn, table)
@@ -577,12 +694,44 @@ class SQLBaseStore(object):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues)
+ ", ".join("?" for _ in allvalues),
)
txn.execute(sql, list(allvalues.values()))
# successfully inserted
return True
+    def _simple_upsert_txn_native_upsert(
+        self, txn, table, keyvalues, values, insertion_values={}
+    ):
+        """
+        Use the database's native UPSERT functionality (PostgreSQL 9.5+ or
+        SQLite 3.24+) by executing a single INSERT ... ON CONFLICT statement.
+
+        Args:
+            txn: The transaction to use.
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+        Returns:
+            None
+        """
+        allvalues = {}
+        allvalues.update(keyvalues)
+        allvalues.update(values)
+        allvalues.update(insertion_values)
+
+        if values:
+            on_conflict = "UPDATE SET " + ", ".join(
+                k + "=EXCLUDED." + k for k in values
+            )
+        else:
+            # With no non-key columns to overwrite, "DO UPDATE SET" followed
+            # by an empty assignment list is a syntax error, so leave any
+            # existing row untouched instead.
+            on_conflict = "NOTHING"
+
+        sql = (
+            "INSERT INTO %s (%s) VALUES (%s) "
+            "ON CONFLICT (%s) DO %s"
+        ) % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues),
+            ", ".join(k for k in keyvalues),
+            on_conflict,
+        )
+        txn.execute(sql, list(allvalues.values()))
+
+
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 5d548f250a..091d7116c5 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -110,8 +110,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
@defer.inlineCallbacks
def _remove_user_ip_dupes(self, progress, batch_size):
+ # This function works by scanning the user_ips table in batches
+ # based on `last_seen`. For each row in a batch it searches the rest of
+ # the table to see if there are any duplicates, if there are then they
+ # are removed and replaced with a suitable row.
- last_seen_progress = progress.get("last_seen", 0)
+ # Fetch the start of the batch
+ begin_last_seen = progress.get("last_seen", 0)
def get_last_seen(txn):
txn.execute(
@@ -122,29 +127,28 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
LIMIT 1
OFFSET ?
""",
- (last_seen_progress, batch_size)
+ (begin_last_seen, batch_size)
)
- results = txn.fetchone()
- return results
-
- # Get a last seen that's sufficiently far away enough from the last one
- last_seen = yield self.runInteraction(
+ row = txn.fetchone()
+ if row:
+ return row[0]
+ else:
+ return None
+
+ # Get a last seen that has roughly `batch_size` rows since `begin_last_seen`
+ end_last_seen = yield self.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)
- if not last_seen:
- # If we get a None then we're reaching the end and just need to
- # delete the last batch.
- last = True
+ # If it returns None, then we're processing the last batch
+ last = end_last_seen is None
- # We fake not having an upper bound by using a future date, by
- # just multiplying the current time by two....
- last_seen = int(self.clock.time_msec()) * 2
- else:
- last = False
- last_seen = last_seen[0]
+ logger.info(
+ "Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
+ begin_last_seen, end_last_seen,
+ )
- def remove(txn, last_seen_progress, last_seen):
+ def remove(txn):
# This works by looking at all entries in the given time span, and
# then for each (user_id, access_token, ip) tuple in that range
# checking for any duplicates in the rest of the table (via a join).
@@ -153,6 +157,16 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# all other duplicates.
# It is efficient due to the existence of (user_id, access_token,
# ip) and (last_seen) indices.
+
+ # Define the search space, which requires handling the last batch in
+ # a different way
+ if last:
+ clause = "? <= last_seen"
+ args = (begin_last_seen,)
+ else:
+ clause = "? <= last_seen AND last_seen < ?"
+ args = (begin_last_seen, end_last_seen)
+
txn.execute(
"""
SELECT user_id, access_token, ip,
@@ -160,13 +174,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
FROM (
SELECT user_id, access_token, ip
FROM user_ips
- WHERE ? <= last_seen AND last_seen < ?
- ORDER BY last_seen
+ WHERE {}
) c
INNER JOIN user_ips USING (user_id, access_token, ip)
GROUP BY user_id, access_token, ip
- HAVING count(*) > 1""",
- (last_seen_progress, last_seen)
+ HAVING count(*) > 1
+ """.format(clause),
+ args
)
res = txn.fetchall()
@@ -194,12 +208,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
self._background_update_progress_txn(
- txn, "user_ips_remove_dupes", {"last_seen": last_seen}
+ txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)
- yield self.runInteraction(
- "user_ips_dups_remove", remove, last_seen_progress, last_seen
- )
+ yield self.runInteraction("user_ips_dups_remove", remove)
+
if last:
yield self._end_background_update("user_ips_remove_dupes")
@@ -244,7 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
def _update_client_ips_batch_txn(self, txn, to_update):
- self.database_engine.lock_table(txn, "user_ips")
+ if "user_ips" in self._unsafe_to_upsert_tables or (
+ not self.database_engine.can_native_upsert
+ ):
+ self.database_engine.lock_table(txn, "user_ips")
for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index e2f9de8451..ff5ef97ca8 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -18,7 +18,7 @@ import platform
from ._base import IncorrectDatabaseSetup
from .postgres import PostgresEngine
-from .sqlite3 import Sqlite3Engine
+from .sqlite import Sqlite3Engine
SUPPORTED_MODULE = {
"sqlite3": Sqlite3Engine,
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 42225f8a2a..4004427c7b 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -38,6 +38,13 @@ class PostgresEngine(object):
return sql.replace("?", "%s")
def on_new_connection(self, db_conn):
+
+ # Get the version of PostgreSQL that we're using. As per the psycopg2
+ # docs: The number is formed by converting the major, minor, and
+ # revision numbers into two-decimal-digit numbers and appending them
+ # together. For example, version 8.1.5 will be returned as 80105
+ self._version = db_conn.server_version
+
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
@@ -54,6 +61,13 @@ class PostgresEngine(object):
cursor.close()
+ @property
+ def can_native_upsert(self):
+ """
+ Can we use native UPSERTs? This requires PostgreSQL 9.5+.
+ """
+ return self._version >= 90500
+
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite.py
index 19949fc474..c64d73ff21 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite.py
@@ -15,6 +15,7 @@
import struct
import threading
+from sqlite3 import sqlite_version_info
from synapse.storage.prepare_database import prepare_database
@@ -30,6 +31,14 @@ class Sqlite3Engine(object):
self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock()
+ @property
+ def can_native_upsert(self):
+ """
+ Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
+ more work we haven't done yet to tell what was inserted vs updated.
+ """
+ return sqlite_version_info >= (3, 24, 0)
+
def check_database(self, txn):
pass
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 2743b52bad..134297e284 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so _simple_upsert will retry
- newly_inserted = yield self._simple_upsert(
+ yield self._simple_upsert(
table="pushers",
keyvalues={
"app_id": app_id,
@@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore):
lock=False,
)
- if newly_inserted:
+ user_has_pusher = self.get_if_user_has_pusher.cache.get(
+ (user_id,), None, update_metrics=False
+ )
+
+ if user_has_pusher is not True:
+ # invalidate, since the user might not have had a pusher before
yield self.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index a8781b0e5d..ce48212265 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -168,14 +168,14 @@ class UserDirectoryStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
# We weight the localpart most highly, then display name and finally
# server name
- if new_entry:
+ if self.database_engine.can_native_upsert:
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
- )
+ ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
txn.execute(
sql,
@@ -185,20 +185,45 @@ class UserDirectoryStore(SQLBaseStore):
)
)
else:
- sql = """
- UPDATE user_directory_search
- SET vector = setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
- WHERE user_id = ?
- """
- txn.execute(
- sql,
- (
- get_localpart_from_id(user_id), get_domain_from_id(user_id),
- display_name, user_id,
+ # TODO: Remove this code after we've bumped the minimum version
+ # of postgres to always support upserts, so we can get rid of
+ # `new_entry` usage
+ if new_entry is True:
+ sql = """
+ INSERT INTO user_directory_search(user_id, vector)
+ VALUES (?,
+ setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ )
+ """
+ txn.execute(
+ sql,
+ (
+ user_id, get_localpart_from_id(user_id),
+ get_domain_from_id(user_id), display_name,
+ )
+ )
+ elif new_entry is False:
+ sql = """
+ UPDATE user_directory_search
+ SET vector = setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (
+ get_localpart_from_id(user_id),
+ get_domain_from_id(user_id),
+ display_name, user_id,
+ )
+ )
+ else:
+ raise RuntimeError(
+ "upsert returned None when 'can_native_upsert' is False"
)
- )
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name,) if display_name else user_id
self._simple_upsert_txn(
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
new file mode 100644
index 0000000000..7a3881f558
--- /dev/null
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -0,0 +1,240 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+import treq
+
+from twisted.internet import defer
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.test.ssl_helpers import ServerTLSContext
+from twisted.web.http import HTTPChannel
+
+from synapse.crypto.context_factory import ClientTLSOptionsFactory
+from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.util.logcontext import LoggingContext
+
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.unittest import TestCase
+
+logger = logging.getLogger(__name__)
+
+
+class MatrixFederationAgentTests(TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ self.mock_resolver = Mock()
+
+ self.agent = MatrixFederationAgent(
+ reactor=self.reactor,
+ tls_client_options_factory=ClientTLSOptionsFactory(None),
+ _srv_resolver=self.mock_resolver,
+ )
+
+ def _make_connection(self, client_factory, expected_sni):
+ """Builds a test server, and completes the outgoing client connection
+
+ Returns:
+ HTTPChannel: the test server
+ """
+
+ # build the test server
+ server_tls_protocol = _build_test_server()
+
+ # now, tell the client protocol factory to build the client protocol (it will be a
+ # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
+ # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor))
+
+ # tell the server tls protocol to send its stuff back to the client, too
+ server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
+
+ # give the reactor a pump to get the TLS juices flowing.
+ self.reactor.pump((0.1,))
+
+ # check the SNI
+ server_name = server_tls_protocol._tlsConnection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ # fish the test server back out of the server-side TLS protocol.
+ return server_tls_protocol.wrappedProtocol
+
+ @defer.inlineCallbacks
+ def _make_get_request(self, uri):
+ """
+ Sends a simple GET request via the agent, and checks its logcontext management
+ """
+ with LoggingContext("one") as context:
+ fetch_d = self.agent.request(b'GET', uri)
+
+ # Nothing happened yet
+ self.assertNoResult(fetch_d)
+
+ # should have reset logcontext to the sentinel
+ _check_logcontext(LoggingContext.sentinel)
+
+ try:
+ fetch_res = yield fetch_d
+ defer.returnValue(fetch_res)
+ finally:
+ _check_logcontext(context)
+
+ def test_get(self):
+ """
+ happy-path test of a GET request
+ """
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+ test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'testserv:8448']
+ )
+ content = request.content.read()
+ self.assertEqual(content, b'')
+
+ # Deferred is still without a result
+ self.assertNoResult(test_d)
+
+ # send the headers
+ request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json'])
+ request.write('')
+
+ self.reactor.pump((0.1,))
+
+ response = self.successResultOf(test_d)
+
+ # that should give us a Response object
+ self.assertEqual(response.code, 200)
+
+ # Send the body
+ request.write('{ "a": 1 }'.encode('ascii'))
+ request.finish()
+
+ self.reactor.pump((0.1,))
+
+ # check it can be read
+ json = self.successResultOf(treq.json_content(response))
+ self.assertEqual(json, {"a": 1})
+
+ def test_get_ip_address(self):
+ """
+ Test the behaviour when the server name contains an explicit IP (with no port)
+ """
+
+ # the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
+ self.mock_resolver.resolve_service.side_effect = lambda _: []
+
+ # then there will be a getaddrinfo on the IP
+ self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_called_once()
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=None,
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ # XXX currently broken
+ # self.assertEqual(
+ # request.requestHeaders.getRawHeaders(b'host'),
+ # [b'1.2.3.4:8448']
+ # )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+
+def _check_logcontext(context):
+ current = LoggingContext.current_context()
+ if current is not context:
+ raise AssertionError(
+ "Expected logcontext %s but was %s" % (context, current),
+ )
+
+
+def _build_test_server():
+ """Construct a test server
+
+ This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
+
+ Returns:
+ TLSMemoryBIOProtocol
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ server_tls_factory = TLSMemoryBIOFactory(
+ ServerTLSContext(), isClient=False, wrappedFactory=server_factory,
+ )
+
+ return server_tls_factory.buildProtocol(None)
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 1271a495e1..a872e2441e 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.error import ConnectError
from twisted.names import dns, error
-from synapse.http.federation.srv_resolver import resolve_service
+from synapse.http.federation.srv_resolver import SrvResolver
from synapse.util.logcontext import LoggingContext
from tests import unittest
@@ -43,13 +43,13 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = result_deferred
cache = {}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@defer.inlineCallbacks
def do_lookup():
+
with LoggingContext("one") as ctx:
- resolve_d = resolve_service(
- service_name, dns_client=dns_client_mock, cache=cache
- )
+ resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
@@ -83,16 +83,15 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
- service_name = "test_service.example.com"
+ service_name = b"test_service.example.com"
entry = Mock(spec_set=["expires"])
entry.expires = 0
cache = {service_name: [entry]}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolve_service(
- service_name, dns_client=dns_client_mock, cache=cache
- )
+ servers = yield resolver.resolve_service(service_name)
dns_client_mock.lookupService.assert_called_once_with(service_name)
@@ -106,17 +105,18 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock(spec_set=['lookupService'])
dns_client_mock.lookupService = Mock(spec_set=[])
- service_name = "test_service.example.com"
+ service_name = b"test_service.example.com"
entry = Mock(spec_set=["expires"])
entry.expires = 999999999
cache = {service_name: [entry]}
-
- servers = yield resolve_service(
- service_name, dns_client=dns_client_mock, cache=cache, clock=clock
+ resolver = SrvResolver(
+ dns_client=dns_client_mock, cache=cache, get_time=clock.time,
)
+ servers = yield resolver.resolve_service(service_name)
+
self.assertFalse(dns_client_mock.lookupService.called)
self.assertEquals(len(servers), 1)
@@ -128,12 +128,13 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
- service_name = "test_service.example.com"
+ service_name = b"test_service.example.com"
cache = {}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError):
- yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache)
+ yield resolver.resolve_service(service_name)
@defer.inlineCallbacks
def test_name_error(self):
@@ -141,13 +142,12 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
- service_name = "test_service.example.com"
+ service_name = b"test_service.example.com"
cache = {}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolve_service(
- service_name, dns_client=dns_client_mock, cache=cache
- )
+ servers = yield resolver.resolve_service(service_name)
self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)
@@ -162,10 +162,9 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolve_service(
- service_name, dns_client=dns_client_mock, cache=cache
- )
+ resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
# returning a single "." should make the lookup fail with a ConenctError
@@ -187,10 +186,9 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolve_service(
- service_name, dns_client=dns_client_mock, cache=cache
- )
+ resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
lookup_deferred.callback((
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 8426eee400..d37f8f9981 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -15,6 +15,7 @@
from mock import Mock
+from twisted.internet import defer
from twisted.internet.defer import TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
from twisted.test.proto_helpers import StringTransport
@@ -26,11 +27,20 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient,
MatrixFederationRequest,
)
+from synapse.util.logcontext import LoggingContext
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase
+def check_logcontext(context):
+ current = LoggingContext.current_context()
+ if current is not context:
+ raise AssertionError(
+ "Expected logcontext %s but was %s" % (context, current),
+ )
+
+
class FederationClientTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
@@ -43,6 +53,70 @@ class FederationClientTests(HomeserverTestCase):
self.cl = MatrixFederationHttpClient(self.hs)
self.reactor.lookups["testserv"] = "1.2.3.4"
+ def test_client_get(self):
+ """
+ happy-path test of a GET request
+ """
+ @defer.inlineCallbacks
+ def do_request():
+ with LoggingContext("one") as context:
+ fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(fetch_d)
+
+ # should have reset logcontext to the sentinel
+ check_logcontext(LoggingContext.sentinel)
+
+ try:
+ fetch_res = yield fetch_d
+ defer.returnValue(fetch_res)
+ finally:
+ check_logcontext(context)
+
+ test_d = do_request()
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8008)
+
+ # complete the connection and wire it up to a fake transport
+ protocol = factory.buildProtocol(None)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ # that should have made it send the request to the transport
+ self.assertRegex(transport.value(), b"^GET /foo/bar")
+
+ # Deferred is still without a result
+ self.assertNoResult(test_d)
+
+ # Send it the HTTP response
+ res_json = '{ "a": 1 }'.encode('ascii')
+ protocol.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Server: Fake\r\n"
+ b"Content-Type: application/json\r\n"
+ b"Content-Length: %i\r\n"
+ b"\r\n"
+ b"%s" % (len(res_json), res_json)
+ )
+
+ self.pump()
+
+ res = self.successResultOf(test_d)
+
+ # check the response is as expected
+ self.assertEqual(res, {"a": 1})
+
def test_dns_error(self):
"""
If the DNS lookup returns an error, it will bubble up.
@@ -54,6 +128,28 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
+ def test_client_connection_refused(self):
+ d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8008)
+ e = Exception("go away")
+ factory.clientConnectionFailed(None, e)
+ self.pump(0.5)
+
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, RequestSendFailed)
+ self.assertIs(f.value.inner_exception, e)
+
def test_client_never_connect(self):
"""
If the HTTP request is not connected and is timed out, it'll give a
diff --git a/tests/server.py b/tests/server.py
index db43fa0db8..ed2a046ae6 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,4 +1,5 @@
import json
+import logging
from io import BytesIO
from six import text_type
@@ -22,6 +23,8 @@ from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth
+logger = logging.getLogger(__name__)
+
class TimedOutException(Exception):
"""
@@ -339,7 +342,7 @@ def get_clock():
return (clock, hs_clock)
-@attr.s
+@attr.s(cmp=False)
class FakeTransport(object):
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -414,6 +417,11 @@ class FakeTransport(object):
self.buffer = self.buffer + byt
def _write():
+ if not self.buffer:
+ # nothing to do. Don't write empty buffers: it upsets the
+ # TLSMemoryBIOProtocol
+ return
+
if getattr(self.other, "transport") is not None:
self.other.dataReceived(self.buffer)
self.buffer = b""
@@ -421,7 +429,10 @@ class FakeTransport(object):
self._reactor.callLater(0.0, _write)
- _write()
+ # always actually do the write asynchronously. Some protocols (notably the
+ # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
+ # still doing a write. Doing a callLater here breaks the cycle.
+ self._reactor.callLater(0.0, _write)
def writeSequence(self, seq):
for x in seq:
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 829f47d2e8..452d76ddd5 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.db_pool.runWithConnection = runWithConnection
config = Mock()
+ config._enable_native_upserts = False
config.event_cache_size = 1
config.database_config = {"name": "sqlite3"}
hs = TestHomeServer(
diff --git a/tests/test_server.py b/tests/test_server.py
index 634a8fbca5..08fb3fe02f 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -19,7 +19,7 @@ from six import StringIO
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -30,12 +30,18 @@ from synapse.util import Clock
from synapse.util.logcontext import make_deferred_yieldable
from tests import unittest
-from tests.server import FakeTransport, make_request, render, setup_test_homeserver
+from tests.server import (
+ FakeTransport,
+ ThreadedMemoryReactorClock,
+ make_request,
+ render,
+ setup_test_homeserver,
+)
class JsonResourceTests(unittest.TestCase):
def setUp(self):
- self.reactor = MemoryReactorClock()
+ self.reactor = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
diff --git a/tests/unittest.py b/tests/unittest.py
index 78d2f740f9..cda549c783 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -96,7 +96,7 @@ class TestCase(unittest.TestCase):
method = getattr(self, methodName)
- level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
+ level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING))
@around(self)
def setUp(orig):
@@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase):
"""
kwargs = dict(kwargs)
kwargs.update(self._hs_args)
- return setup_test_homeserver(self.addCleanup, *args, **kwargs)
+ hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
+ stor = hs.get_datastore()
+
+ # Run the database background updates.
+ if hasattr(stor, "do_next_background_update"):
+ while not self.get_success(stor.has_completed_background_updates()):
+ self.get_success(stor.do_next_background_update(1))
+
+ return hs
def pump(self, by=0.0):
"""
diff --git a/tox.ini b/tox.ini
index a0f5486829..9b2d78ed6d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -149,4 +149,5 @@ deps =
codecov
commands =
coverage combine
+ coverage xml
codecov -X gcov
|