summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rwxr-xr-xsynapse/app/homeserver.py56
-rw-r--r--synapse/config/_base.py4
-rw-r--r--synapse/config/registration.py9
-rw-r--r--synapse/config/tls.py115
-rw-r--r--synapse/handlers/acme.py147
-rw-r--r--synapse/python_dependencies.py4
-rw-r--r--synapse/rest/client/v2_alpha/register.py16
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/storage/client_ips.py65
9 files changed, 349 insertions, 72 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index f3ac3d19f0..ffc49d77cc 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -13,10 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import gc
 import logging
 import os
 import sys
+import traceback
 
 from six import iteritems
 
@@ -324,17 +326,12 @@ def setup(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    tls_server_context_factory = context_factory.ServerContextFactory(config)
-    tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
     database_engine = create_engine(config.database_config)
     config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
 
     hs = SynapseHomeServer(
         config.server_name,
         db_config=config.database_config,
-        tls_server_context_factory=tls_server_context_factory,
-        tls_client_options_factory=tls_client_options_factory,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,
@@ -361,12 +358,53 @@ def setup(config_options):
     logger.info("Database prepared in %s.", config.database_config['name'])
 
     hs.setup()
-    hs.start_listening()
 
+    @defer.inlineCallbacks
     def start():
-        hs.get_pusherpool().start()
-        hs.get_datastore().start_profiling()
-        hs.get_datastore().start_doing_background_updates()
+        try:
+            # Check if the certificate is still valid.
+            cert_days_remaining = hs.config.is_disk_cert_valid()
+
+            if hs.config.acme_enabled:
+                # If ACME is enabled, we might need to provision a certificate
+                # before starting.
+                acme = hs.get_acme_handler()
+
+                # Start up the webservices which we will respond to ACME
+                # challenges with.
+                yield acme.start_listening()
+
+                # We want to reprovision if cert_days_remaining is None (meaning no
+                # certificate exists), or the days remaining number it returns
+                # is less than our re-registration threshold.
+                if (cert_days_remaining is None) or (
+                    not cert_days_remaining > hs.config.acme_reprovision_threshold
+                ):
+                    yield acme.provision_certificate()
+
+            # Read the certificate from disk and build the context factories for
+            # TLS.
+            hs.config.read_certificate_from_disk()
+            hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
+            hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
+                config
+            )
+
+            # It is now safe to start your Synapse.
+            hs.start_listening()
+            hs.get_pusherpool().start()
+            hs.get_datastore().start_profiling()
+            hs.get_datastore().start_doing_background_updates()
+        except Exception as e:
+            # If a DeferredList failed (like in listening on the ACME listener),
+            # we need to print the subfailure explicitly.
+            if isinstance(e, defer.FirstError):
+                e.subFailure.printTraceback(sys.stderr)
+                sys.exit(1)
+
+            # Something else went wrong when starting. Print it and bail out.
+            traceback.print_exc(file=sys.stderr)
+            sys.exit(1)
 
     reactor.callWhenRunning(start)
 
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index fd2d6d52ef..5858fb92b4 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -367,7 +367,7 @@ class Config(object):
         if not keys_directory:
             keys_directory = os.path.dirname(config_files[-1])
 
-        config_dir_path = os.path.abspath(keys_directory)
+        self.config_dir_path = os.path.abspath(keys_directory)
 
         specified_config = {}
         for config_file in config_files:
@@ -379,7 +379,7 @@ class Config(object):
 
         server_name = specified_config["server_name"]
         config_string = self.generate_config(
-            config_dir_path=config_dir_path,
+            config_dir_path=self.config_dir_path,
             data_dir_path=os.getcwd(),
             server_name=server_name,
             generate_secrets=False,
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 6c2b543b8c..fe520d6855 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -50,6 +50,10 @@ class RegistrationConfig(Config):
                 raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
         self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
 
+        self.disable_msisdn_registration = (
+            config.get("disable_msisdn_registration", False)
+        )
+
     def default_config(self, generate_secrets=False, **kwargs):
         if generate_secrets:
             registration_shared_secret = 'registration_shared_secret: "%s"' % (
@@ -70,6 +74,11 @@ class RegistrationConfig(Config):
         #     - email
         #     - msisdn
 
+        # Explicitly disable asking for MSISDNs from the registration
+        # flow (overrides registrations_require_3pid if MSISDNs are set as required)
+        #
+        # disable_msisdn_registration = True
+
         # Mandate that users are only allowed to associate certain formats of
         # 3PIDs with accounts on this server.
         #
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index bb8952c672..a75e233aa0 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -13,60 +13,110 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
+from datetime import datetime
 from hashlib import sha256
 
 from unpaddedbase64 import encode_base64
 
 from OpenSSL import crypto
 
-from ._base import Config
+from synapse.config._base import Config
+
+logger = logging.getLogger()
 
 
 class TlsConfig(Config):
     def read_config(self, config):
-        self.tls_certificate = self.read_tls_certificate(
-            config.get("tls_certificate_path")
-        )
-        self.tls_certificate_file = config.get("tls_certificate_path")
 
+        acme_config = config.get("acme", {})
+        self.acme_enabled = acme_config.get("enabled", False)
+        self.acme_url = acme_config.get(
+            "url", "https://acme-v01.api.letsencrypt.org/directory"
+        )
+        self.acme_port = acme_config.get("port", 8449)
+        self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"])
+        self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
+
+        self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path"))
+        self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path"))
+        self._original_tls_fingerprints = config["tls_fingerprints"]
+        self.tls_fingerprints = list(self._original_tls_fingerprints)
         self.no_tls = config.get("no_tls", False)
 
-        if self.no_tls:
-            self.tls_private_key = None
-        else:
-            self.tls_private_key = self.read_tls_private_key(
-                config.get("tls_private_key_path")
-            )
+        # This config option applies to non-federation HTTP clients
+        # (e.g. for talking to recaptcha, identity servers, and such)
+        # It should never be used in production, and is intended for
+        # use only when running tests.
+        self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
+            "use_insecure_ssl_client_just_for_testing_do_not_use"
+        )
+
+        self.tls_certificate = None
+        self.tls_private_key = None
+
+    def is_disk_cert_valid(self):
+        """
+        Is the certificate we have on disk valid, and if so, for how long?
+
+        Returns:
+            int: Days remaining of certificate validity.
+            None: No certificate exists.
+        """
+        if not os.path.exists(self.tls_certificate_file):
+            return None
+
+        try:
+            with open(self.tls_certificate_file, 'rb') as f:
+                cert_pem = f.read()
+        except Exception:
+            logger.exception("Failed to read existing certificate off disk!")
+            raise
+
+        try:
+            tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
+        except Exception:
+            logger.exception("Failed to parse existing certificate off disk!")
+            raise
+
+        # YYYYMMDDhhmmssZ -- in UTC
+        expires_on = datetime.strptime(
+            tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
+        )
+        now = datetime.utcnow()
+        days_remaining = (expires_on - now).days
+        return days_remaining
 
-        self.tls_fingerprints = config["tls_fingerprints"]
+    def read_certificate_from_disk(self):
+        """
+        Read the certificates from disk.
+        """
+        self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file)
+
+        if not self.no_tls:
+            self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file)
+
+        self.tls_fingerprints = list(self._original_tls_fingerprints)
 
         # Check that our own certificate is included in the list of fingerprints
         # and include it if it is not.
         x509_certificate_bytes = crypto.dump_certificate(
-            crypto.FILETYPE_ASN1,
-            self.tls_certificate
+            crypto.FILETYPE_ASN1, self.tls_certificate
         )
         sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
         sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
         if sha256_fingerprint not in sha256_fingerprints:
             self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
 
-        # This config option applies to non-federation HTTP clients
-        # (e.g. for talking to recaptcha, identity servers, and such)
-        # It should never be used in production, and is intended for
-        # use only when running tests.
-        self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
-            "use_insecure_ssl_client_just_for_testing_do_not_use"
-        )
-
     def default_config(self, config_dir_path, server_name, **kwargs):
         base_key_name = os.path.join(config_dir_path, server_name)
 
         tls_certificate_path = base_key_name + ".tls.crt"
         tls_private_key_path = base_key_name + ".tls.key"
 
-        return """\
+        return (
+            """\
         # PEM encoded X509 certificate for TLS.
         # You can replace the self-signed certificate that synapse
         # autogenerates on launch with your own SSL certificate + key pair
@@ -107,7 +157,24 @@ class TlsConfig(Config):
         #
         tls_fingerprints: []
         # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
-        """ % locals()
+
+        ## Support for ACME certificate auto-provisioning.
+        # acme:
+        #    enabled: false
+        ##   ACME path.
+        ##   If you only want to test, use the staging url:
+        ##   https://acme-staging.api.letsencrypt.org/directory
+        #    url: 'https://acme-v01.api.letsencrypt.org/directory'
+        ##   Port number (to listen for the HTTP-01 challenge).
+        ##   Using port 80 requires utilising something like authbind, or proxying to it.
+        #    port: 8449
+        ##   Hosts to bind to.
+        #    bind_addresses: ['127.0.0.1']
+        ##   How many days remaining on a certificate before it is renewed.
+        #    reprovision_threshold: 30
+        """
+            % locals()
+        )
 
     def read_tls_certificate(self, cert_path):
         cert_pem = self.read_file(cert_path, "tls_certificate")
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
new file mode 100644
index 0000000000..73ea7ed018
--- /dev/null
+++ b/synapse/handlers/acme.py
@@ -0,0 +1,147 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import serverFromString
+from twisted.python.filepath import FilePath
+from twisted.python.url import URL
+from twisted.web import server, static
+from twisted.web.resource import Resource
+
+logger = logging.getLogger(__name__)
+
+try:
+    from txacme.interfaces import ICertificateStore
+
+    @attr.s
+    @implementer(ICertificateStore)
+    class ErsatzStore(object):
+        """
+        A store that only stores in memory.
+        """
+
+        certs = attr.ib(default=attr.Factory(dict))
+
+        def store(self, server_name, pem_objects):
+            self.certs[server_name] = [o.as_bytes() for o in pem_objects]
+            return defer.succeed(None)
+
+
+except ImportError:
+    # txacme is missing
+    pass
+
+
+class AcmeHandler(object):
+    def __init__(self, hs):
+        self.hs = hs
+        self.reactor = hs.get_reactor()
+
+    @defer.inlineCallbacks
+    def start_listening(self):
+
+        # Configure logging for txacme, if you need to debug
+        # from eliot import add_destinations
+        # from eliot.twisted import TwistedDestination
+        #
+        # add_destinations(TwistedDestination())
+
+        from txacme.challenges import HTTP01Responder
+        from txacme.service import AcmeIssuingService
+        from txacme.endpoint import load_or_create_client_key
+        from txacme.client import Client
+        from josepy.jwa import RS256
+
+        self._store = ErsatzStore()
+        responder = HTTP01Responder()
+
+        self._issuer = AcmeIssuingService(
+            cert_store=self._store,
+            client_creator=(
+                lambda: Client.from_url(
+                    reactor=self.reactor,
+                    url=URL.from_text(self.hs.config.acme_url),
+                    key=load_or_create_client_key(
+                        FilePath(self.hs.config.config_dir_path)
+                    ),
+                    alg=RS256,
+                )
+            ),
+            clock=self.reactor,
+            responders=[responder],
+        )
+
+        well_known = Resource()
+        well_known.putChild(b'acme-challenge', responder.resource)
+        responder_resource = Resource()
+        responder_resource.putChild(b'.well-known', well_known)
+        responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
+
+        srv = server.Site(responder_resource)
+
+        listeners = []
+
+        for host in self.hs.config.acme_bind_addresses:
+            logger.info(
+                "Listening for ACME requests on %s:%s", host, self.hs.config.acme_port
+            )
+            endpoint = serverFromString(
+                self.reactor, "tcp:%s:interface=%s" % (self.hs.config.acme_port, host)
+            )
+            listeners.append(endpoint.listen(srv))
+
+        # Make sure we are registered to the ACME server. There's no public API
+        # for this, it is usually triggered by startService, but since we don't
+        # want it to control where we save the certificates, we have to reach in
+        # and trigger the registration machinery ourselves.
+        self._issuer._registered = False
+        yield self._issuer._ensure_registered()
+
+        # Return a Deferred that will fire when all the servers have started up.
+        yield defer.DeferredList(listeners, fireOnOneErrback=True, consumeErrors=True)
+
+    @defer.inlineCallbacks
+    def provision_certificate(self):
+
+        logger.warning("Reprovisioning %s", self.hs.hostname)
+
+        try:
+            yield self._issuer.issue_cert(self.hs.hostname)
+        except Exception:
+            logger.exception("Fail!")
+            raise
+        logger.warning("Reprovisioned %s, saving.", self.hs.hostname)
+        cert_chain = self._store.certs[self.hs.hostname]
+
+        try:
+            with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
+                for x in cert_chain:
+                    if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
+                        private_key_file.write(x)
+
+            with open(self.hs.config.tls_certificate_file, "wb") as certificate_file:
+                for x in cert_chain:
+                    if x.startswith(b"-----BEGIN CERTIFICATE-----"):
+                        certificate_file.write(x)
+        except Exception:
+            logger.exception("Failed saving!")
+            raise
+
+        defer.returnValue(True)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 882e844eb1..756721e304 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -79,6 +79,10 @@ CONDITIONAL_REQUIREMENTS = {
     # ConsentResource uses select_autoescape, which arrived in jinja 2.9
     "resources.consent": ["Jinja2>=2.9"],
 
+    # ACME support is required to provision TLS certificates from authorities
+    # that use the protocol, such as Let's Encrypt.
+    "acme": ["txacme>=0.9.2"],
+
     "saml2": ["pysaml2>=4.5.0"],
     "url_preview": ["lxml>=3.5.0"],
     "test": ["mock>=2.0"],
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index aec0c6b075..14025cd219 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -309,22 +309,16 @@ class RegisterRestServlet(RestServlet):
                 assigned_user_id=registered_user_id,
             )
 
-        # Only give msisdn flows if the x_show_msisdn flag is given:
-        # this is a hack to work around the fact that clients were shipped
-        # that use fallback registration if they see any flows that they don't
-        # recognise, which means we break registration for these clients if we
-        # advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot
-        # Android <=0.6.9 have fallen below an acceptable threshold, this
-        # parameter should go away and we should always advertise msisdn flows.
-        show_msisdn = False
-        if 'x_show_msisdn' in body and body['x_show_msisdn']:
-            show_msisdn = True
-
         # FIXME: need a better error than "no auth flow found" for scenarios
         # where we required 3PID for registration but the user didn't give one
         require_email = 'email' in self.hs.config.registrations_require_3pid
         require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
 
+        show_msisdn = True
+        if self.hs.config.disable_msisdn_registration:
+            show_msisdn = False
+            require_msisdn = False
+
         flows = []
         if self.hs.config.enable_registration_captcha:
             # only support 3PIDless registration if no 3PIDs are required
diff --git a/synapse/server.py b/synapse/server.py
index 9985687b95..c8914302cf 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -46,6 +46,7 @@ from synapse.federation.transport.client import TransportLayerClient
 from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
 from synapse.groups.groups_server import GroupsServerHandler
 from synapse.handlers import Handlers
+from synapse.handlers.acme import AcmeHandler
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.handlers.auth import AuthHandler, MacaroonGenerator
 from synapse.handlers.deactivate_account import DeactivateAccountHandler
@@ -129,6 +130,7 @@ class HomeServer(object):
         'sync_handler',
         'typing_handler',
         'room_list_handler',
+        'acme_handler',
         'auth_handler',
         'device_handler',
         'e2e_keys_handler',
@@ -310,6 +312,9 @@ class HomeServer(object):
     def build_e2e_room_keys_handler(self):
         return E2eRoomKeysHandler(self)
 
+    def build_acme_handler(self):
+        return AcmeHandler(self)
+
     def build_application_service_api(self):
         return ApplicationServiceApi(self)
 
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 5d548f250a..b228a20ac2 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -110,8 +110,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
     @defer.inlineCallbacks
     def _remove_user_ip_dupes(self, progress, batch_size):
+        # This works function works by scanning the user_ips table in batches
+        # based on `last_seen`. For each row in a batch it searches the rest of
+        # the table to see if there are any duplicates, if there are then they
+        # are removed and replaced with a suitable row.
 
-        last_seen_progress = progress.get("last_seen", 0)
+        # Fetch the start of the batch
+        begin_last_seen = progress.get("last_seen", 0)
 
         def get_last_seen(txn):
             txn.execute(
@@ -122,29 +127,28 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 LIMIT 1
                 OFFSET ?
                 """,
-                (last_seen_progress, batch_size)
+                (begin_last_seen, batch_size)
             )
-            results = txn.fetchone()
-            return results
-
-        # Get a last seen that's sufficiently far away enough from the last one
-        last_seen = yield self.runInteraction(
+            row = txn.fetchone()
+            if row:
+                return row[0]
+            else:
+                return None
+
+        # Get a last seen that has roughly `batch_size` since `begin_last_seen`
+        end_last_seen = yield self.runInteraction(
             "user_ips_dups_get_last_seen", get_last_seen
         )
 
-        if not last_seen:
-            # If we get a None then we're reaching the end and just need to
-            # delete the last batch.
-            last = True
+        # If it returns None, then we're processing the last batch
+        last = end_last_seen is None
 
-            # We fake not having an upper bound by using a future date, by
-            # just multiplying the current time by two....
-            last_seen = int(self.clock.time_msec()) * 2
-        else:
-            last = False
-            last_seen = last_seen[0]
+        logger.info(
+            "Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
+            begin_last_seen, end_last_seen,
+        )
 
-        def remove(txn, last_seen_progress, last_seen):
+        def remove(txn):
             # This works by looking at all entries in the given time span, and
             # then for each (user_id, access_token, ip) tuple in that range
             # checking for any duplicates in the rest of the table (via a join).
@@ -153,6 +157,16 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             # all other duplicates.
             # It is efficient due to the existence of (user_id, access_token,
             # ip) and (last_seen) indices.
+
+            # Define the search space, which requires handling the last batch in
+            # a different way
+            if last:
+                clause = "? <= last_seen"
+                args = (begin_last_seen,)
+            else:
+                clause = "? <= last_seen AND last_seen < ?"
+                args = (begin_last_seen, end_last_seen)
+
             txn.execute(
                 """
                 SELECT user_id, access_token, ip,
@@ -160,13 +174,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 FROM (
                     SELECT user_id, access_token, ip
                     FROM user_ips
-                    WHERE ? <= last_seen AND last_seen < ?
-                    ORDER BY last_seen
+                    WHERE {}
                 ) c
                 INNER JOIN user_ips USING (user_id, access_token, ip)
                 GROUP BY user_id, access_token, ip
-                HAVING count(*) > 1""",
-                (last_seen_progress, last_seen)
+                HAVING count(*) > 1
+                """.format(clause),
+                args
             )
             res = txn.fetchall()
 
@@ -194,12 +208,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 )
 
             self._background_update_progress_txn(
-                txn, "user_ips_remove_dupes", {"last_seen": last_seen}
+                txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
             )
 
-        yield self.runInteraction(
-            "user_ips_dups_remove", remove, last_seen_progress, last_seen
-        )
+        yield self.runInteraction("user_ips_dups_remove", remove)
+
         if last:
             yield self._end_background_update("user_ips_remove_dupes")