summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/emailconfig.py2
-rw-r--r--synapse/config/key.py34
-rw-r--r--synapse/crypto/keyring.py11
-rw-r--r--synapse/federation/transport/client.py46
-rw-r--r--synapse/federation/transport/server.py36
-rw-r--r--synapse/handlers/profile.py2
-rw-r--r--synapse/handlers/room_list.py29
-rw-r--r--synapse/http/federation/matrix_federation_agent.py364
-rw-r--r--synapse/http/federation/srv_resolver.py61
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py28
-rw-r--r--synapse/rest/well_known.py2
-rw-r--r--synapse/storage/prepare_database.py24
12 files changed, 381 insertions, 258 deletions
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 36d01a10af..f83c05df44 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -115,7 +115,7 @@ class EmailConfig(Config):
                     missing.append("email." + k)
 
             if config.get("public_baseurl") is None:
-                missing.append("public_base_url")
+                missing.append("public_baseurl")
 
             if len(missing) > 0:
                 raise RuntimeError(
diff --git a/synapse/config/key.py b/synapse/config/key.py
index fe8386985c..ba2199bceb 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -76,7 +76,7 @@ class KeyConfig(Config):
                     config_dir_path, config["server_name"] + ".signing.key"
                 )
 
-            self.signing_key = self.read_signing_key(signing_key_path)
+            self.signing_key = self.read_signing_keys(signing_key_path, "signing_key")
 
         self.old_signing_keys = self.read_old_signing_keys(
             config.get("old_signing_keys", {})
@@ -85,6 +85,14 @@ class KeyConfig(Config):
             config.get("key_refresh_interval", "1d")
         )
 
+        key_server_signing_keys_path = config.get("key_server_signing_keys_path")
+        if key_server_signing_keys_path:
+            self.key_server_signing_keys = self.read_signing_keys(
+                key_server_signing_keys_path, "key_server_signing_keys_path"
+            )
+        else:
+            self.key_server_signing_keys = list(self.signing_key)
+
         # if neither trusted_key_servers nor perspectives are given, use the default.
         if "perspectives" not in config and "trusted_key_servers" not in config:
             key_servers = [{"server_name": "matrix.org"}]
@@ -210,16 +218,34 @@ class KeyConfig(Config):
         #
         #trusted_key_servers:
         #  - server_name: "matrix.org"
+        #
+
+        # The signing keys to use when acting as a trusted key server. If not specified
+        # defaults to the server signing key.
+        #
+        # Can contain multiple keys, one per line.
+        #
+        #key_server_signing_keys_path: "key_server_signing_keys.key"
         """
             % locals()
         )
 
-    def read_signing_key(self, signing_key_path):
-        signing_keys = self.read_file(signing_key_path, "signing_key")
+    def read_signing_keys(self, signing_key_path, name):
+        """Read the signing keys in the given path.
+
+        Args:
+            signing_key_path (str)
+            name (str): Associated config key name
+
+        Returns:
+            list[SigningKey]
+        """
+
+        signing_keys = self.read_file(signing_key_path, name)
         try:
             return read_signing_keys(signing_keys.splitlines(True))
         except Exception as e:
-            raise ConfigError("Error reading signing_key: %s" % (str(e)))
+            raise ConfigError("Error reading %s: %s" % (name, str(e)))
 
     def read_old_signing_keys(self, old_signing_keys):
         keys = {}
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 654accc843..7cfad192e8 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -29,7 +29,6 @@ from signedjson.key import (
 from signedjson.sign import (
     SignatureVerifyException,
     encode_canonical_json,
-    sign_json,
     signature_ids,
     verify_signed_json,
 )
@@ -539,13 +538,7 @@ class BaseV2KeyFetcher(object):
                     verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
                 )
 
-        # re-sign the json with our own key, so that it is ready if we are asked to
-        # give it out as a notary server
-        signed_key_json = sign_json(
-            response_json, self.config.server_name, self.config.signing_key[0]
-        )
-
-        signed_key_json_bytes = encode_canonical_json(signed_key_json)
+        key_json_bytes = encode_canonical_json(response_json)
 
         yield make_deferred_yieldable(
             defer.gatherResults(
@@ -557,7 +550,7 @@ class BaseV2KeyFetcher(object):
                         from_server=from_server,
                         ts_now_ms=time_added_ms,
                         ts_expires_ms=ts_valid_until_ms,
-                        key_json_bytes=signed_key_json_bytes,
+                        key_json_bytes=key_json_bytes,
                     )
                     for key_id in verify_keys
                 ],
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 0cea0d2a10..482a101c09 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -327,21 +327,37 @@ class TransportLayerClient(object):
         include_all_networks=False,
         third_party_instance_id=None,
     ):
-        path = _create_v1_path("/publicRooms")
-
-        args = {"include_all_networks": "true" if include_all_networks else "false"}
-        if third_party_instance_id:
-            args["third_party_instance_id"] = (third_party_instance_id,)
-        if limit:
-            args["limit"] = [str(limit)]
-        if since_token:
-            args["since"] = [since_token]
-
-        # TODO(erikj): Actually send the search_filter across federation.
-
-        response = yield self.client.get_json(
-            destination=remote_server, path=path, args=args, ignore_backoff=True
-        )
+        if search_filter:
+            # this uses MSC2197 (Search Filtering over Federation)
+            path = _create_v1_path("/publicRooms")
+
+            data = {"include_all_networks": "true" if include_all_networks else "false"}
+            if third_party_instance_id:
+                data["third_party_instance_id"] = third_party_instance_id
+            if limit:
+                data["limit"] = str(limit)
+            if since_token:
+                data["since"] = since_token
+
+            data["filter"] = search_filter
+
+            response = yield self.client.post_json(
+                destination=remote_server, path=path, data=data, ignore_backoff=True
+            )
+        else:
+            path = _create_v1_path("/publicRooms")
+
+            args = {"include_all_networks": "true" if include_all_networks else "false"}
+            if third_party_instance_id:
+                args["third_party_instance_id"] = (third_party_instance_id,)
+            if limit:
+                args["limit"] = [str(limit)]
+            if since_token:
+                args["since"] = [since_token]
+
+            response = yield self.client.get_json(
+                destination=remote_server, path=path, args=args, ignore_backoff=True
+            )
 
         return response
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index dc53b4b170..f9930b6460 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -770,6 +770,42 @@ class PublicRoomList(BaseFederationServlet):
         )
         return 200, data
 
+    async def on_POST(self, origin, content, query):
+        # This implements MSC2197 (Search Filtering over Federation)
+        if not self.allow_access:
+            raise FederationDeniedError(origin)
+
+        limit = int(content.get("limit", 100))
+        since_token = content.get("since", None)
+        search_filter = content.get("filter", None)
+
+        include_all_networks = content.get("include_all_networks", False)
+        third_party_instance_id = content.get("third_party_instance_id", None)
+
+        if include_all_networks:
+            network_tuple = None
+            if third_party_instance_id is not None:
+                raise SynapseError(
+                    400, "Can't use include_all_networks with an explicit network"
+                )
+        elif third_party_instance_id is None:
+            network_tuple = ThirdPartyInstanceID(None, None)
+        else:
+            network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+
+        if search_filter is None:
+            logger.warning("Nonefilter")
+
+        data = await self.handler.get_local_public_room_list(
+            limit=limit,
+            since_token=since_token,
+            search_filter=search_filter,
+            network_tuple=network_tuple,
+            from_federation=True,
+        )
+
+        return 200, data
+
 
 class FederationVersionServlet(BaseFederationServlet):
     PATH = "/version"
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 2cc237e6a5..8690f69d45 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -34,7 +34,7 @@ from ._base import BaseHandler
 
 logger = logging.getLogger(__name__)
 
-MAX_DISPLAYNAME_LEN = 100
+MAX_DISPLAYNAME_LEN = 256
 MAX_AVATAR_URL_LEN = 1000
 
 
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index e9094ad02b..a7e55f00e5 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -25,6 +25,7 @@ from unpaddedbase64 import decode_base64, encode_base64
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.errors import Codes, HttpResponseException
 from synapse.types import ThirdPartyInstanceID
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.descriptors import cachedInlineCallbacks
@@ -485,7 +486,33 @@ class RoomListHandler(BaseHandler):
             return {"chunk": [], "total_room_count_estimate": 0}
 
         if search_filter:
-            # We currently don't support searching across federation, so we have
+            # Searching across federation is defined in MSC2197.
+            # However, the remote homeserver may or may not actually support it.
+            # So we first try an MSC2197 remote-filtered search, then fall back
+            # to a locally-filtered search if we must.
+
+            try:
+                res = yield self._get_remote_list_cached(
+                    server_name,
+                    limit=limit,
+                    since_token=since_token,
+                    include_all_networks=include_all_networks,
+                    third_party_instance_id=third_party_instance_id,
+                    search_filter=search_filter,
+                )
+                return res
+            except HttpResponseException as hre:
+                syn_err = hre.to_synapse_error()
+                if hre.code in (404, 405) or syn_err.errcode in (
+                    Codes.UNRECOGNIZED,
+                    Codes.NOT_FOUND,
+                ):
+                    logger.debug("Falling back to locally-filtered /publicRooms")
+                else:
+                    raise  # Not an error that should trigger a fallback.
+
+            # if we reach this point, then we fall back to the situation where
+            # we currently don't support searching across federation, so we have
             # to do it manually without pagination
             limit = None
             since_token = None
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 64f62aaeec..feae7de5be 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,21 +14,21 @@
 # limitations under the License.
 
 import logging
+import urllib
 
-import attr
-from netaddr import IPAddress
+from netaddr import AddrFormatError, IPAddress
 from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import URI, Agent, HTTPConnectionPool
+from twisted.web.client import Agent, HTTPConnectionPool
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent
+from twisted.web.iweb import IAgent, IAgentEndpointFactory
 
-from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import WellKnownResolver
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -36,8 +36,9 @@ logger = logging.getLogger(__name__)
 
 @implementer(IAgent)
 class MatrixFederationAgent(object):
-    """An Agent-like thing which provides a `request` method which will look up a matrix
-    server and send an HTTP request to it.
+    """An Agent-like thing which provides a `request` method which correctly
+    handles resolving matrix server names when using matrix://. Handles standard
+    https URIs as normal.
 
     Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
 
@@ -65,17 +66,19 @@ class MatrixFederationAgent(object):
     ):
         self._reactor = reactor
         self._clock = Clock(reactor)
-
-        self._tls_client_options_factory = tls_client_options_factory
-        if _srv_resolver is None:
-            _srv_resolver = SrvResolver()
-        self._srv_resolver = _srv_resolver
-
         self._pool = HTTPConnectionPool(reactor)
         self._pool.retryAutomatically = False
         self._pool.maxPersistentPerHost = 5
         self._pool.cachedConnectionTimeout = 2 * 60
 
+        self._agent = Agent.usingEndpointFactory(
+            self._reactor,
+            MatrixHostnameEndpointFactory(
+                reactor, tls_client_options_factory, _srv_resolver
+            ),
+            pool=self._pool,
+        )
+
         if _well_known_resolver is None:
             _well_known_resolver = WellKnownResolver(
                 self._reactor,
@@ -93,19 +96,15 @@ class MatrixFederationAgent(object):
         """
         Args:
             method (bytes): HTTP method: GET/POST/etc
-
             uri (bytes): Absolute URI to be retrieved
-
             headers (twisted.web.http_headers.Headers|None):
                 HTTP headers to send with the request, or None to
                 send no extra headers.
-
             bodyProducer (twisted.web.iweb.IBodyProducer|None):
                 An object which can generate bytes to make up the
                 body of this request (for example, the properly encoded contents of
                 a file for a file upload).  Or None if the request is to have
                 no body.
-
         Returns:
             Deferred[twisted.web.iweb.IResponse]:
                 fires when the header of the response has been received (regardless of the
@@ -113,210 +112,207 @@ class MatrixFederationAgent(object):
                 response from being received (including problems that prevent the request
                 from being sent).
         """
-        parsed_uri = URI.fromBytes(uri, defaultPort=-1)
-        res = yield self._route_matrix_uri(parsed_uri)
+        # We use urlparse as that will set `port` to None if there is no
+        # explicit port.
+        parsed_uri = urllib.parse.urlparse(uri)
 
-        # set up the TLS connection params
+        # If this is a matrix:// URI check if the server has delegated matrix
+        # traffic using well-known delegation.
         #
-        # XXX disabling TLS is really only supported here for the benefit of the
-        # unit tests. We should make the UTs cope with TLS rather than having to make
-        # the code support the unit tests.
-        if self._tls_client_options_factory is None:
-            tls_options = None
-        else:
-            tls_options = self._tls_client_options_factory.get_options(
-                res.tls_server_name.decode("ascii")
+        # We have to do this here and not in the endpoint as we need to rewrite
+        # the host header with the delegated server name.
+        delegated_server = None
+        if (
+            parsed_uri.scheme == b"matrix"
+            and not _is_ip_literal(parsed_uri.hostname)
+            and not parsed_uri.port
+        ):
+            well_known_result = yield self._well_known_resolver.get_well_known(
+                parsed_uri.hostname
             )
+            delegated_server = well_known_result.delegated_server
+
+        if delegated_server:
+            # Ok, the server has delegated matrix traffic to somewhere else, so
+            # lets rewrite the URL to replace the server with the delegated
+            # server name.
+            uri = urllib.parse.urlunparse(
+                (
+                    parsed_uri.scheme,
+                    delegated_server,
+                    parsed_uri.path,
+                    parsed_uri.params,
+                    parsed_uri.query,
+                    parsed_uri.fragment,
+                )
+            )
+            parsed_uri = urllib.parse.urlparse(uri)
 
-        # make sure that the Host header is set correctly
+        # We need to make sure the host header is set to the netloc of the
+        # server.
         if headers is None:
             headers = Headers()
         else:
             headers = headers.copy()
 
         if not headers.hasHeader(b"host"):
-            headers.addRawHeader(b"host", res.host_header)
+            headers.addRawHeader(b"host", parsed_uri.netloc)
 
-        class EndpointFactory(object):
-            @staticmethod
-            def endpointForURI(_uri):
-                ep = LoggingHostnameEndpoint(
-                    self._reactor, res.target_host, res.target_port
-                )
-                if tls_options is not None:
-                    ep = wrapClientTLS(tls_options, ep)
-                return ep
-
-        agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
         res = yield make_deferred_yieldable(
-            agent.request(method, uri, headers, bodyProducer)
+            self._agent.request(method, uri, headers, bodyProducer)
         )
+
         return res
 
-    @defer.inlineCallbacks
-    def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
-        """Helper for `request`: determine the routing for a Matrix URI
 
-        Args:
-            parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
-                parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
-                if there is no explicit port given.
+@implementer(IAgentEndpointFactory)
+class MatrixHostnameEndpointFactory(object):
+    """Factory for MatrixHostnameEndpoint for parsing to an Agent.
+    """
 
-            lookup_well_known (bool): True if we should look up the .well-known file if
-                there is no SRV record.
+    def __init__(self, reactor, tls_client_options_factory, srv_resolver):
+        self._reactor = reactor
+        self._tls_client_options_factory = tls_client_options_factory
 
-        Returns:
-            Deferred[_RoutingResult]
-        """
-        # check for an IP literal
-        try:
-            ip_address = IPAddress(parsed_uri.host.decode("ascii"))
-        except Exception:
-            # not an IP address
-            ip_address = None
-
-        if ip_address:
-            port = parsed_uri.port
-            if port == -1:
-                port = 8448
-            return _RoutingResult(
-                host_header=parsed_uri.netloc,
-                tls_server_name=parsed_uri.host,
-                target_host=parsed_uri.host,
-                target_port=port,
-            )
+        if srv_resolver is None:
+            srv_resolver = SrvResolver()
 
-        if parsed_uri.port != -1:
-            # there is an explicit port
-            return _RoutingResult(
-                host_header=parsed_uri.netloc,
-                tls_server_name=parsed_uri.host,
-                target_host=parsed_uri.host,
-                target_port=parsed_uri.port,
-            )
+        self._srv_resolver = srv_resolver
 
-        if lookup_well_known:
-            # try a .well-known lookup
-            well_known_result = yield self._well_known_resolver.get_well_known(
-                parsed_uri.host
-            )
-            well_known_server = well_known_result.delegated_server
-
-            if well_known_server:
-                # if we found a .well-known, start again, but don't do another
-                # .well-known lookup.
-
-                # parse the server name in the .well-known response into host/port.
-                # (This code is lifted from twisted.web.client.URI.fromBytes).
-                if b":" in well_known_server:
-                    well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
-                    try:
-                        well_known_port = int(well_known_port)
-                    except ValueError:
-                        # the part after the colon could not be parsed as an int
-                        # - we assume it is an IPv6 literal with no port (the closing
-                        # ']' stops it being parsed as an int)
-                        well_known_host, well_known_port = well_known_server, -1
-                else:
-                    well_known_host, well_known_port = well_known_server, -1
-
-                new_uri = URI(
-                    scheme=parsed_uri.scheme,
-                    netloc=well_known_server,
-                    host=well_known_host,
-                    port=well_known_port,
-                    path=parsed_uri.path,
-                    params=parsed_uri.params,
-                    query=parsed_uri.query,
-                    fragment=parsed_uri.fragment,
-                )
+    def endpointForURI(self, parsed_uri):
+        return MatrixHostnameEndpoint(
+            self._reactor,
+            self._tls_client_options_factory,
+            self._srv_resolver,
+            parsed_uri,
+        )
 
-                res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
-                return res
-
-        # try a SRV lookup
-        service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
-        server_list = yield self._srv_resolver.resolve_service(service_name)
-
-        if not server_list:
-            target_host = parsed_uri.host
-            port = 8448
-            logger.debug(
-                "No SRV record for %s, using %s:%i",
-                parsed_uri.host.decode("ascii"),
-                target_host.decode("ascii"),
-                port,
-            )
+
+@implementer(IStreamClientEndpoint)
+class MatrixHostnameEndpoint(object):
+    """An endpoint that resolves matrix:// URLs using Matrix server name
+    resolution (i.e. via SRV). Does not check for well-known delegation.
+
+    Args:
+        reactor (IReactor)
+        tls_client_options_factory (ClientTLSOptionsFactory|None):
+            factory to use for fetching client tls options, or none to disable TLS.
+        srv_resolver (SrvResolver): The SRV resolver to use
+        parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
+            to connect to.
+    """
+
+    def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
+        self._reactor = reactor
+
+        self._parsed_uri = parsed_uri
+
+        # set up the TLS connection params
+        #
+        # XXX disabling TLS is really only supported here for the benefit of the
+        # unit tests. We should make the UTs cope with TLS rather than having to make
+        # the code support the unit tests.
+
+        if tls_client_options_factory is None:
+            self._tls_options = None
         else:
-            target_host, port = pick_server_from_list(server_list)
-            logger.debug(
-                "Picked %s:%i from SRV records for %s",
-                target_host.decode("ascii"),
-                port,
-                parsed_uri.host.decode("ascii"),
+            self._tls_options = tls_client_options_factory.get_options(
+                self._parsed_uri.host.decode("ascii")
             )
 
-        return _RoutingResult(
-            host_header=parsed_uri.netloc,
-            tls_server_name=parsed_uri.host,
-            target_host=target_host,
-            target_port=port,
-        )
+        self._srv_resolver = srv_resolver
 
+    def connect(self, protocol_factory):
+        """Implements IStreamClientEndpoint interface
+        """
 
-@implementer(IStreamClientEndpoint)
-class LoggingHostnameEndpoint(object):
-    """A wrapper for HostnameEndpint which logs when it connects"""
+        return run_in_background(self._do_connect, protocol_factory)
 
-    def __init__(self, reactor, host, port, *args, **kwargs):
-        self.host = host
-        self.port = port
-        self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
+    @defer.inlineCallbacks
+    def _do_connect(self, protocol_factory):
+        first_exception = None
+
+        server_list = yield self._resolve_server()
+
+        for server in server_list:
+            host = server.host
+            port = server.port
+
+            try:
+                logger.info("Connecting to %s:%i", host.decode("ascii"), port)
+                endpoint = HostnameEndpoint(self._reactor, host, port)
+                if self._tls_options:
+                    endpoint = wrapClientTLS(self._tls_options, endpoint)
+                result = yield make_deferred_yieldable(
+                    endpoint.connect(protocol_factory)
+                )
 
-    def connect(self, protocol_factory):
-        logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
-        return self.ep.connect(protocol_factory)
+                return result
+            except Exception as e:
+                logger.info(
+                    "Failed to connect to %s:%i: %s", host.decode("ascii"), port, e
+                )
+                if not first_exception:
+                    first_exception = e
 
+        # We return the first failure because that's probably the most interesting.
+        if first_exception:
+            raise first_exception
 
-@attr.s
-class _RoutingResult(object):
-    """The result returned by `_route_matrix_uri`.
+        # This shouldn't happen as we should always have at least one host/port
+        # to try and if that doesn't work then we'll have an exception.
+        raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
 
-    Contains the parameters needed to direct a federation connection to a particular
-    server.
+    @defer.inlineCallbacks
+    def _resolve_server(self):
+        """Resolves the server name to a list of hosts and ports to attempt to
+        connect to.
 
-    Where a SRV record points to several servers, this object contains a single server
-    chosen from the list.
-    """
+        Returns:
+            Deferred[list[Server]]
+        """
 
-    host_header = attr.ib()
-    """
-    The value we should assign to the Host header (host:port from the matrix
-    URI, or .well-known).
+        if self._parsed_uri.scheme != b"matrix":
+            return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)]
 
-    :type: bytes
-    """
+        # Note: We don't do well-known lookup as that needs to have happened
+        # before now, due to needing to rewrite the Host header of the HTTP
+        # request.
 
-    tls_server_name = attr.ib()
-    """
-    The server name we should set in the SNI (typically host, without port, from the
-    matrix URI or .well-known)
+        # We reparse the URI so that defaultPort is -1 rather than 80
+        parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes())
 
-    :type: bytes
-    """
+        host = parsed_uri.hostname
+        port = parsed_uri.port
 
-    target_host = attr.ib()
-    """
-    The hostname (or IP literal) we should route the TCP connection to (the target of the
-    SRV record, or the hostname from the URL/.well-known)
+        # If there is an explicit port or the host is an IP address we bypass
+        # SRV lookups and just use the given host/port.
+        if port or _is_ip_literal(host):
+            return [Server(host, port or 8448)]
 
-    :type: bytes
-    """
+        server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
 
-    target_port = attr.ib()
-    """
-    The port we should route the TCP connection to (the target of the SRV record, or
-    the port from the URL/.well-known, or 8448)
+        if server_list:
+            return server_list
+
+        # No SRV records, so we fallback to host and 8448
+        return [Server(host, 8448)]
+
+
+def _is_ip_literal(host):
+    """Test if the given host name is either an IPv4 or IPv6 literal.
 
-    :type: int
+    Args:
+        host (bytes)
+
+    Returns:
+        bool
     """
+
+    host = host.decode("ascii")
+
+    try:
+        IPAddress(host)
+        return True
+    except AddrFormatError:
+        return False
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b32188766d..3fe4ffb9e5 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
 SERVER_CACHE = {}
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class Server(object):
     """
     Our record of an individual server which can be tried to reach a destination.
@@ -53,34 +53,47 @@ class Server(object):
     expires = attr.ib(default=0)
 
 
-def pick_server_from_list(server_list):
-    """Randomly choose a server from the server list
+def _sort_server_list(server_list):
+    """Given a list of SRV records sort them into priority order and shuffle
+    each priority with the given weight.
+    """
+    priority_map = {}
 
-    Args:
-        server_list (list[Server]): list of candidate servers
+    for server in server_list:
+        priority_map.setdefault(server.priority, []).append(server)
 
-    Returns:
-        Tuple[bytes, int]: (host, port) pair for the chosen server
-    """
-    if not server_list:
-        raise RuntimeError("pick_server_from_list called with empty list")
+    results = []
+    for priority in sorted(priority_map):
+        servers = priority_map[priority]
+
+        # This algorithms roughly follows the algorithm described in RFC2782,
+        # changed to remove an off-by-one error.
+        #
+        # N.B. Weights can be zero, which means that they should be picked
+        # rarely.
+
+        total_weight = sum(s.weight for s in servers)
+
+        # Total weight can become zero if there are only zero weight servers
+        # left, which we handle by just shuffling and appending to the results.
+        while servers and total_weight:
+            target_weight = random.randint(1, total_weight)
 
-    # TODO: currently we only use the lowest-priority servers. We should maintain a
-    # cache of servers known to be "down" and filter them out
+            for s in servers:
+                target_weight -= s.weight
 
-    min_priority = min(s.priority for s in server_list)
-    eligible_servers = list(s for s in server_list if s.priority == min_priority)
-    total_weight = sum(s.weight for s in eligible_servers)
-    target_weight = random.randint(0, total_weight)
+                if target_weight <= 0:
+                    break
 
-    for s in eligible_servers:
-        target_weight -= s.weight
+            results.append(s)
+            servers.remove(s)
+            total_weight -= s.weight
 
-        if target_weight <= 0:
-            return s.host, s.port
+        if servers:
+            random.shuffle(servers)
+            results.extend(servers)
 
-    # this should be impossible.
-    raise RuntimeError("pick_server_from_list got to end of eligible server list.")
+    return results
 
 
 class SrvResolver(object):
@@ -120,7 +133,7 @@ class SrvResolver(object):
         if cache_entry:
             if all(s.expires > now for s in cache_entry):
                 servers = list(cache_entry)
-                return servers
+                return _sort_server_list(servers)
 
         try:
             answers, _, _ = yield make_deferred_yieldable(
@@ -169,4 +182,4 @@ class SrvResolver(object):
             )
 
         self._cache[service_name] = list(servers)
-        return servers
+        return _sort_server_list(servers)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 031a316693..55580bc59e 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 import logging
-from io import BytesIO
+
+from canonicaljson import encode_canonical_json, json
+from signedjson.sign import sign_json
 
 from twisted.internet import defer
 
@@ -95,6 +97,7 @@ class RemoteKey(DirectServeResource):
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+        self.config = hs.config
 
     @wrap_json_request_handler
     async def _async_render_GET(self, request):
@@ -214,15 +217,14 @@ class RemoteKey(DirectServeResource):
             yield self.fetcher.get_keys(cache_misses)
             yield self.query_keys(request, query, query_remote_on_cache_miss=False)
         else:
-            result_io = BytesIO()
-            result_io.write(b'{"server_keys":')
-            sep = b"["
-            for json_bytes in json_results:
-                result_io.write(sep)
-                result_io.write(json_bytes)
-                sep = b","
-            if sep == b"[":
-                result_io.write(sep)
-            result_io.write(b"]}")
-
-            respond_with_json_bytes(request, 200, result_io.getvalue())
+            signed_keys = []
+            for key_json in json_results:
+                key_json = json.loads(key_json)
+                for signing_key in self.config.key_server_signing_keys:
+                    key_json = sign_json(key_json, self.config.server_name, signing_key)
+
+                signed_keys.append(key_json)
+
+            results = {"server_keys": signed_keys}
+
+            respond_with_json_bytes(request, 200, encode_canonical_json(results))
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 5e8fda4b65..20177b44e7 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -34,7 +34,7 @@ class WellKnownBuilder(object):
         self._config = hs.config
 
     def get_well_known(self):
-        # if we don't have a public_base_url, we can't help much here.
+        # if we don't have a public_baseurl, we can't help much here.
         if self._config.public_baseurl is None:
             return None
 
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index d20eacda59..e96eed8a6d 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -238,6 +238,13 @@ def _upgrade_existing_database(
 
     logger.debug("applied_delta_files: %s", applied_delta_files)
 
+    if isinstance(database_engine, PostgresEngine):
+        specific_engine_extension = ".postgres"
+    else:
+        specific_engine_extension = ".sqlite"
+
+    specific_engine_extensions = (".sqlite", ".postgres")
+
     for v in range(start_ver, SCHEMA_VERSION + 1):
         logger.info("Upgrading schema to v%d", v)
 
@@ -274,15 +281,22 @@ def _upgrade_existing_database(
                 # Sometimes .pyc files turn up anyway even though we've
                 # disabled their generation; e.g. from distribution package
                 # installers. Silently skip it
-                pass
+                continue
             elif ext == ".sql":
                 # A plain old .sql file, just read and execute it
                 logger.info("Applying schema %s", relative_path)
                 executescript(cur, absolute_path)
+            elif ext == specific_engine_extension and root_name.endswith(".sql"):
+                # A .sql file specific to our engine; just read and execute it
+                logger.info("Applying engine-specific schema %s", relative_path)
+                executescript(cur, absolute_path)
+            elif ext in specific_engine_extensions and root_name.endswith(".sql"):
+                # A .sql file for a different engine; skip it.
+                continue
             else:
                 # Not a valid delta file.
-                logger.warn(
-                    "Found directory entry that did not end in .py or" " .sql: %s",
+                logger.warning(
+                    "Found directory entry that did not end in .py or .sql: %s",
                     relative_path,
                 )
                 continue
@@ -290,7 +304,7 @@ def _upgrade_existing_database(
             # Mark as done.
             cur.execute(
                 database_engine.convert_param_style(
-                    "INSERT INTO applied_schema_deltas (version, file)" " VALUES (?,?)"
+                    "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
                 ),
                 (v, relative_path),
             )
@@ -298,7 +312,7 @@ def _upgrade_existing_database(
             cur.execute("DELETE FROM schema_version")
             cur.execute(
                 database_engine.convert_param_style(
-                    "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
+                    "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
                 ),
                 (v, True),
             )