diff --git a/changelog.d/5859.feature b/changelog.d/5859.feature
new file mode 100644
index 0000000000..52df7fc81b
--- /dev/null
+++ b/changelog.d/5859.feature
@@ -0,0 +1 @@
+Add unstable support for MSC2197 (filtered search requests over federation), in order to allow upcoming room directory query performance improvements.
diff --git a/changelog.d/5864.feature b/changelog.d/5864.feature
new file mode 100644
index 0000000000..40ac11db64
--- /dev/null
+++ b/changelog.d/5864.feature
@@ -0,0 +1 @@
+Correctly retry all hosts returned from SRV when we fail to connect.
diff --git a/changelog.d/5895.feature b/changelog.d/5895.feature
new file mode 100644
index 0000000000..c394a3772c
--- /dev/null
+++ b/changelog.d/5895.feature
@@ -0,0 +1 @@
+Add config option to sign remote key query responses with a separate key.
diff --git a/changelog.d/5906.feature b/changelog.d/5906.feature
new file mode 100644
index 0000000000..7c789510a6
--- /dev/null
+++ b/changelog.d/5906.feature
@@ -0,0 +1 @@
+Increase max display name size to 256.
diff --git a/changelog.d/5909.misc b/changelog.d/5909.misc
new file mode 100644
index 0000000000..03d0c4367b
--- /dev/null
+++ b/changelog.d/5909.misc
@@ -0,0 +1 @@
+Fix error message which referred to public_base_url instead of public_baseurl. Thanks to @aaronraimist for the fix!
diff --git a/changelog.d/5911.misc b/changelog.d/5911.misc
new file mode 100644
index 0000000000..fe5a8fd59c
--- /dev/null
+++ b/changelog.d/5911.misc
@@ -0,0 +1 @@
+Add support for database engine-specific schema deltas, based on file extension.
\ No newline at end of file
diff --git a/docker/README.md b/docker/README.md
index 46bb9d2d99..d5879c2f2c 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -17,7 +17,7 @@ By default, the image expects a single volume, located at ``/data``, that will h
* the appservices configuration.
You are free to use separate volumes depending on storage endpoints at your
-disposal. For instance, ``/data/media`` coud be stored on a large but low
+disposal. For instance, ``/data/media`` could be stored on a large but low
performance hdd storage while other files could be stored on high performance
endpoints.
@@ -27,8 +27,8 @@ configuration file there. Multiple application services are supported.
## Generating a configuration file
-The first step is to genearte a valid config file. To do this, you can run the
-image with the `generate` commandline option.
+The first step is to generate a valid config file. To do this, you can run the
+image with the `generate` command line option.
You will need to specify values for the `SYNAPSE_SERVER_NAME` and
`SYNAPSE_REPORT_STATS` environment variable, and mount a docker volume to store
@@ -59,7 +59,7 @@ The following environment variables are supported in `generate` mode:
* `SYNAPSE_CONFIG_PATH`: path to the file to be generated. Defaults to
`<SYNAPSE_CONFIG_DIR>/homeserver.yaml`.
* `SYNAPSE_DATA_DIR`: where the generated config will put persistent data
- such as the datatase and media store. Defaults to `/data`.
+ such as the database and media store. Defaults to `/data`.
* `UID`, `GID`: the user id and group id to use for creating the data
directories. Defaults to `991`, `991`.
@@ -115,7 +115,7 @@ not given).
To migrate from a dynamic configuration file to a static one, run the docker
container once with the environment variables set, and `migrate_config`
-commandline option. For example:
+command line option. For example:
```
docker run -it --rm \
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 0c6be30e51..ae1cafc5f3 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1027,6 +1027,14 @@ signing_key_path: "CONFDIR/SERVERNAME.signing.key"
#
#trusted_key_servers:
# - server_name: "matrix.org"
+#
+
+# The signing keys to use when acting as a trusted key server. If not specified
+# defaults to the server signing key.
+#
+# Can contain multiple keys, one per line.
+#
+#key_server_signing_keys_path: "key_server_signing_keys.key"
# Enable SAML2 for registration and login. Uses pysaml2.
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 36d01a10af..f83c05df44 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -115,7 +115,7 @@ class EmailConfig(Config):
missing.append("email." + k)
if config.get("public_baseurl") is None:
- missing.append("public_base_url")
+ missing.append("public_baseurl")
if len(missing) > 0:
raise RuntimeError(
diff --git a/synapse/config/key.py b/synapse/config/key.py
index fe8386985c..ba2199bceb 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -76,7 +76,7 @@ class KeyConfig(Config):
config_dir_path, config["server_name"] + ".signing.key"
)
- self.signing_key = self.read_signing_key(signing_key_path)
+ self.signing_key = self.read_signing_keys(signing_key_path, "signing_key")
self.old_signing_keys = self.read_old_signing_keys(
config.get("old_signing_keys", {})
@@ -85,6 +85,14 @@ class KeyConfig(Config):
config.get("key_refresh_interval", "1d")
)
+ key_server_signing_keys_path = config.get("key_server_signing_keys_path")
+ if key_server_signing_keys_path:
+ self.key_server_signing_keys = self.read_signing_keys(
+ key_server_signing_keys_path, "key_server_signing_keys_path"
+ )
+ else:
+ self.key_server_signing_keys = list(self.signing_key)
+
# if neither trusted_key_servers nor perspectives are given, use the default.
if "perspectives" not in config and "trusted_key_servers" not in config:
key_servers = [{"server_name": "matrix.org"}]
@@ -210,16 +218,34 @@ class KeyConfig(Config):
#
#trusted_key_servers:
# - server_name: "matrix.org"
+ #
+
+ # The signing keys to use when acting as a trusted key server. If not specified
+ # defaults to the server signing key.
+ #
+ # Can contain multiple keys, one per line.
+ #
+ #key_server_signing_keys_path: "key_server_signing_keys.key"
"""
% locals()
)
- def read_signing_key(self, signing_key_path):
- signing_keys = self.read_file(signing_key_path, "signing_key")
+ def read_signing_keys(self, signing_key_path, name):
+ """Read the signing keys in the given path.
+
+ Args:
+ signing_key_path (str)
+ name (str): Associated config key name
+
+ Returns:
+ list[SigningKey]
+ """
+
+ signing_keys = self.read_file(signing_key_path, name)
try:
return read_signing_keys(signing_keys.splitlines(True))
except Exception as e:
- raise ConfigError("Error reading signing_key: %s" % (str(e)))
+ raise ConfigError("Error reading %s: %s" % (name, str(e)))
def read_old_signing_keys(self, old_signing_keys):
keys = {}
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 654accc843..7cfad192e8 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -29,7 +29,6 @@ from signedjson.key import (
from signedjson.sign import (
SignatureVerifyException,
encode_canonical_json,
- sign_json,
signature_ids,
verify_signed_json,
)
@@ -539,13 +538,7 @@ class BaseV2KeyFetcher(object):
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
- # re-sign the json with our own key, so that it is ready if we are asked to
- # give it out as a notary server
- signed_key_json = sign_json(
- response_json, self.config.server_name, self.config.signing_key[0]
- )
-
- signed_key_json_bytes = encode_canonical_json(signed_key_json)
+ key_json_bytes = encode_canonical_json(response_json)
yield make_deferred_yieldable(
defer.gatherResults(
@@ -557,7 +550,7 @@ class BaseV2KeyFetcher(object):
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
- key_json_bytes=signed_key_json_bytes,
+ key_json_bytes=key_json_bytes,
)
for key_id in verify_keys
],
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 0cea0d2a10..482a101c09 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -327,21 +327,37 @@ class TransportLayerClient(object):
include_all_networks=False,
third_party_instance_id=None,
):
- path = _create_v1_path("/publicRooms")
-
- args = {"include_all_networks": "true" if include_all_networks else "false"}
- if third_party_instance_id:
- args["third_party_instance_id"] = (third_party_instance_id,)
- if limit:
- args["limit"] = [str(limit)]
- if since_token:
- args["since"] = [since_token]
-
- # TODO(erikj): Actually send the search_filter across federation.
-
- response = yield self.client.get_json(
- destination=remote_server, path=path, args=args, ignore_backoff=True
- )
+ if search_filter:
+ # this uses MSC2197 (Search Filtering over Federation)
+ path = _create_v1_path("/publicRooms")
+
+ data = {"include_all_networks": "true" if include_all_networks else "false"}
+ if third_party_instance_id:
+ data["third_party_instance_id"] = third_party_instance_id
+ if limit:
+ data["limit"] = str(limit)
+ if since_token:
+ data["since"] = since_token
+
+ data["filter"] = search_filter
+
+ response = yield self.client.post_json(
+ destination=remote_server, path=path, data=data, ignore_backoff=True
+ )
+ else:
+ path = _create_v1_path("/publicRooms")
+
+ args = {"include_all_networks": "true" if include_all_networks else "false"}
+ if third_party_instance_id:
+ args["third_party_instance_id"] = (third_party_instance_id,)
+ if limit:
+ args["limit"] = [str(limit)]
+ if since_token:
+ args["since"] = [since_token]
+
+ response = yield self.client.get_json(
+ destination=remote_server, path=path, args=args, ignore_backoff=True
+ )
return response
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index dc53b4b170..f9930b6460 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -770,6 +770,42 @@ class PublicRoomList(BaseFederationServlet):
)
return 200, data
+ async def on_POST(self, origin, content, query):
+ # This implements MSC2197 (Search Filtering over Federation)
+ if not self.allow_access:
+ raise FederationDeniedError(origin)
+
+ limit = int(content.get("limit", 100))
+ since_token = content.get("since", None)
+ search_filter = content.get("filter", None)
+
+ include_all_networks = content.get("include_all_networks", False)
+ third_party_instance_id = content.get("third_party_instance_id", None)
+
+ if include_all_networks:
+ network_tuple = None
+ if third_party_instance_id is not None:
+ raise SynapseError(
+ 400, "Can't use include_all_networks with an explicit network"
+ )
+ elif third_party_instance_id is None:
+ network_tuple = ThirdPartyInstanceID(None, None)
+ else:
+ network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+
+ if search_filter is None:
+ logger.warning("Nonefilter")
+
+ data = await self.handler.get_local_public_room_list(
+ limit=limit,
+ since_token=since_token,
+ search_filter=search_filter,
+ network_tuple=network_tuple,
+ from_federation=True,
+ )
+
+ return 200, data
+
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 2cc237e6a5..8690f69d45 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -34,7 +34,7 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
-MAX_DISPLAYNAME_LEN = 100
+MAX_DISPLAYNAME_LEN = 256
MAX_AVATAR_URL_LEN = 1000
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index e9094ad02b..a7e55f00e5 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -25,6 +25,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.errors import Codes, HttpResponseException
from synapse.types import ThirdPartyInstanceID
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
@@ -485,7 +486,33 @@ class RoomListHandler(BaseHandler):
return {"chunk": [], "total_room_count_estimate": 0}
if search_filter:
- # We currently don't support searching across federation, so we have
+ # Searching across federation is defined in MSC2197.
+ # However, the remote homeserver may or may not actually support it.
+ # So we first try an MSC2197 remote-filtered search, then fall back
+ # to a locally-filtered search if we must.
+
+ try:
+ res = yield self._get_remote_list_cached(
+ server_name,
+ limit=limit,
+ since_token=since_token,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
+ search_filter=search_filter,
+ )
+ return res
+ except HttpResponseException as hre:
+ syn_err = hre.to_synapse_error()
+ if hre.code in (404, 405) or syn_err.errcode in (
+ Codes.UNRECOGNIZED,
+ Codes.NOT_FOUND,
+ ):
+ logger.debug("Falling back to locally-filtered /publicRooms")
+ else:
+ raise # Not an error that should trigger a fallback.
+
+ # if we reach this point, then we fall back to the situation where
+ # we currently don't support searching across federation, so we have
# to do it manually without pagination
limit = None
since_token = None
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 64f62aaeec..feae7de5be 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,21 +14,21 @@
# limitations under the License.
import logging
+import urllib
-import attr
-from netaddr import IPAddress
+from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import URI, Agent, HTTPConnectionPool
+from twisted.web.client import Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent
+from twisted.web.iweb import IAgent, IAgentEndpointFactory
-from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -36,8 +36,9 @@ logger = logging.getLogger(__name__)
@implementer(IAgent)
class MatrixFederationAgent(object):
- """An Agent-like thing which provides a `request` method which will look up a matrix
- server and send an HTTP request to it.
+ """An Agent-like thing which provides a `request` method which correctly
+ handles resolving matrix server names when using matrix://. Handles standard
+ https URIs as normal.
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
@@ -65,17 +66,19 @@ class MatrixFederationAgent(object):
):
self._reactor = reactor
self._clock = Clock(reactor)
-
- self._tls_client_options_factory = tls_client_options_factory
- if _srv_resolver is None:
- _srv_resolver = SrvResolver()
- self._srv_resolver = _srv_resolver
-
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60
+ self._agent = Agent.usingEndpointFactory(
+ self._reactor,
+ MatrixHostnameEndpointFactory(
+ reactor, tls_client_options_factory, _srv_resolver
+ ),
+ pool=self._pool,
+ )
+
if _well_known_resolver is None:
_well_known_resolver = WellKnownResolver(
self._reactor,
@@ -93,19 +96,15 @@ class MatrixFederationAgent(object):
"""
Args:
method (bytes): HTTP method: GET/POST/etc
-
uri (bytes): Absolute URI to be retrieved
-
headers (twisted.web.http_headers.Headers|None):
HTTP headers to send with the request, or None to
send no extra headers.
-
bodyProducer (twisted.web.iweb.IBodyProducer|None):
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
no body.
-
Returns:
Deferred[twisted.web.iweb.IResponse]:
fires when the header of the response has been received (regardless of the
@@ -113,210 +112,207 @@ class MatrixFederationAgent(object):
response from being received (including problems that prevent the request
from being sent).
"""
- parsed_uri = URI.fromBytes(uri, defaultPort=-1)
- res = yield self._route_matrix_uri(parsed_uri)
+ # We use urlparse as that will set `port` to None if there is no
+ # explicit port.
+ parsed_uri = urllib.parse.urlparse(uri)
- # set up the TLS connection params
+ # If this is a matrix:// URI check if the server has delegated matrix
+ # traffic using well-known delegation.
#
- # XXX disabling TLS is really only supported here for the benefit of the
- # unit tests. We should make the UTs cope with TLS rather than having to make
- # the code support the unit tests.
- if self._tls_client_options_factory is None:
- tls_options = None
- else:
- tls_options = self._tls_client_options_factory.get_options(
- res.tls_server_name.decode("ascii")
+ # We have to do this here and not in the endpoint as we need to rewrite
+ # the host header with the delegated server name.
+ delegated_server = None
+ if (
+ parsed_uri.scheme == b"matrix"
+ and not _is_ip_literal(parsed_uri.hostname)
+ and not parsed_uri.port
+ ):
+ well_known_result = yield self._well_known_resolver.get_well_known(
+ parsed_uri.hostname
)
+ delegated_server = well_known_result.delegated_server
+
+ if delegated_server:
+ # Ok, the server has delegated matrix traffic to somewhere else, so
+ # lets rewrite the URL to replace the server with the delegated
+ # server name.
+ uri = urllib.parse.urlunparse(
+ (
+ parsed_uri.scheme,
+ delegated_server,
+ parsed_uri.path,
+ parsed_uri.params,
+ parsed_uri.query,
+ parsed_uri.fragment,
+ )
+ )
+ parsed_uri = urllib.parse.urlparse(uri)
- # make sure that the Host header is set correctly
+ # We need to make sure the host header is set to the netloc of the
+ # server.
if headers is None:
headers = Headers()
else:
headers = headers.copy()
if not headers.hasHeader(b"host"):
- headers.addRawHeader(b"host", res.host_header)
+ headers.addRawHeader(b"host", parsed_uri.netloc)
- class EndpointFactory(object):
- @staticmethod
- def endpointForURI(_uri):
- ep = LoggingHostnameEndpoint(
- self._reactor, res.target_host, res.target_port
- )
- if tls_options is not None:
- ep = wrapClientTLS(tls_options, ep)
- return ep
-
- agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
res = yield make_deferred_yieldable(
- agent.request(method, uri, headers, bodyProducer)
+ self._agent.request(method, uri, headers, bodyProducer)
)
+
return res
- @defer.inlineCallbacks
- def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
- """Helper for `request`: determine the routing for a Matrix URI
- Args:
- parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
- parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
- if there is no explicit port given.
+@implementer(IAgentEndpointFactory)
+class MatrixHostnameEndpointFactory(object):
+ """Factory for MatrixHostnameEndpoint for parsing to an Agent.
+ """
- lookup_well_known (bool): True if we should look up the .well-known file if
- there is no SRV record.
+ def __init__(self, reactor, tls_client_options_factory, srv_resolver):
+ self._reactor = reactor
+ self._tls_client_options_factory = tls_client_options_factory
- Returns:
- Deferred[_RoutingResult]
- """
- # check for an IP literal
- try:
- ip_address = IPAddress(parsed_uri.host.decode("ascii"))
- except Exception:
- # not an IP address
- ip_address = None
-
- if ip_address:
- port = parsed_uri.port
- if port == -1:
- port = 8448
- return _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=port,
- )
+ if srv_resolver is None:
+ srv_resolver = SrvResolver()
- if parsed_uri.port != -1:
- # there is an explicit port
- return _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=parsed_uri.port,
- )
+ self._srv_resolver = srv_resolver
- if lookup_well_known:
- # try a .well-known lookup
- well_known_result = yield self._well_known_resolver.get_well_known(
- parsed_uri.host
- )
- well_known_server = well_known_result.delegated_server
-
- if well_known_server:
- # if we found a .well-known, start again, but don't do another
- # .well-known lookup.
-
- # parse the server name in the .well-known response into host/port.
- # (This code is lifted from twisted.web.client.URI.fromBytes).
- if b":" in well_known_server:
- well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
- try:
- well_known_port = int(well_known_port)
- except ValueError:
- # the part after the colon could not be parsed as an int
- # - we assume it is an IPv6 literal with no port (the closing
- # ']' stops it being parsed as an int)
- well_known_host, well_known_port = well_known_server, -1
- else:
- well_known_host, well_known_port = well_known_server, -1
-
- new_uri = URI(
- scheme=parsed_uri.scheme,
- netloc=well_known_server,
- host=well_known_host,
- port=well_known_port,
- path=parsed_uri.path,
- params=parsed_uri.params,
- query=parsed_uri.query,
- fragment=parsed_uri.fragment,
- )
+ def endpointForURI(self, parsed_uri):
+ return MatrixHostnameEndpoint(
+ self._reactor,
+ self._tls_client_options_factory,
+ self._srv_resolver,
+ parsed_uri,
+ )
- res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
- return res
-
- # try a SRV lookup
- service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
- server_list = yield self._srv_resolver.resolve_service(service_name)
-
- if not server_list:
- target_host = parsed_uri.host
- port = 8448
- logger.debug(
- "No SRV record for %s, using %s:%i",
- parsed_uri.host.decode("ascii"),
- target_host.decode("ascii"),
- port,
- )
+
+@implementer(IStreamClientEndpoint)
+class MatrixHostnameEndpoint(object):
+ """An endpoint that resolves matrix:// URLs using Matrix server name
+ resolution (i.e. via SRV). Does not check for well-known delegation.
+
+ Args:
+ reactor (IReactor)
+ tls_client_options_factory (ClientTLSOptionsFactory|None):
+ factory to use for fetching client tls options, or none to disable TLS.
+ srv_resolver (SrvResolver): The SRV resolver to use
+ parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
+ to connect to.
+ """
+
+ def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
+ self._reactor = reactor
+
+ self._parsed_uri = parsed_uri
+
+ # set up the TLS connection params
+ #
+ # XXX disabling TLS is really only supported here for the benefit of the
+ # unit tests. We should make the UTs cope with TLS rather than having to make
+ # the code support the unit tests.
+
+ if tls_client_options_factory is None:
+ self._tls_options = None
else:
- target_host, port = pick_server_from_list(server_list)
- logger.debug(
- "Picked %s:%i from SRV records for %s",
- target_host.decode("ascii"),
- port,
- parsed_uri.host.decode("ascii"),
+ self._tls_options = tls_client_options_factory.get_options(
+ self._parsed_uri.host.decode("ascii")
)
- return _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=target_host,
- target_port=port,
- )
+ self._srv_resolver = srv_resolver
+ def connect(self, protocol_factory):
+ """Implements IStreamClientEndpoint interface
+ """
-@implementer(IStreamClientEndpoint)
-class LoggingHostnameEndpoint(object):
- """A wrapper for HostnameEndpint which logs when it connects"""
+ return run_in_background(self._do_connect, protocol_factory)
- def __init__(self, reactor, host, port, *args, **kwargs):
- self.host = host
- self.port = port
- self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
+ @defer.inlineCallbacks
+ def _do_connect(self, protocol_factory):
+ first_exception = None
+
+ server_list = yield self._resolve_server()
+
+ for server in server_list:
+ host = server.host
+ port = server.port
+
+ try:
+ logger.info("Connecting to %s:%i", host.decode("ascii"), port)
+ endpoint = HostnameEndpoint(self._reactor, host, port)
+ if self._tls_options:
+ endpoint = wrapClientTLS(self._tls_options, endpoint)
+ result = yield make_deferred_yieldable(
+ endpoint.connect(protocol_factory)
+ )
- def connect(self, protocol_factory):
- logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
- return self.ep.connect(protocol_factory)
+ return result
+ except Exception as e:
+ logger.info(
+ "Failed to connect to %s:%i: %s", host.decode("ascii"), port, e
+ )
+ if not first_exception:
+ first_exception = e
+ # We return the first failure because that's probably the most interesting.
+ if first_exception:
+ raise first_exception
-@attr.s
-class _RoutingResult(object):
- """The result returned by `_route_matrix_uri`.
+ # This shouldn't happen as we should always have at least one host/port
+ # to try and if that doesn't work then we'll have an exception.
+ raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
- Contains the parameters needed to direct a federation connection to a particular
- server.
+ @defer.inlineCallbacks
+ def _resolve_server(self):
+ """Resolves the server name to a list of hosts and ports to attempt to
+ connect to.
- Where a SRV record points to several servers, this object contains a single server
- chosen from the list.
- """
+ Returns:
+ Deferred[list[Server]]
+ """
- host_header = attr.ib()
- """
- The value we should assign to the Host header (host:port from the matrix
- URI, or .well-known).
+ if self._parsed_uri.scheme != b"matrix":
+ return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)]
- :type: bytes
- """
+ # Note: We don't do well-known lookup as that needs to have happened
+ # before now, due to needing to rewrite the Host header of the HTTP
+ # request.
- tls_server_name = attr.ib()
- """
- The server name we should set in the SNI (typically host, without port, from the
- matrix URI or .well-known)
+ # We reparse the URI so that defaultPort is -1 rather than 80
+ parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes())
- :type: bytes
- """
+ host = parsed_uri.hostname
+ port = parsed_uri.port
- target_host = attr.ib()
- """
- The hostname (or IP literal) we should route the TCP connection to (the target of the
- SRV record, or the hostname from the URL/.well-known)
+ # If there is an explicit port or the host is an IP address we bypass
+ # SRV lookups and just use the given host/port.
+ if port or _is_ip_literal(host):
+ return [Server(host, port or 8448)]
- :type: bytes
- """
+ server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
- target_port = attr.ib()
- """
- The port we should route the TCP connection to (the target of the SRV record, or
- the port from the URL/.well-known, or 8448)
+ if server_list:
+ return server_list
+
+ # No SRV records, so we fallback to host and 8448
+ return [Server(host, 8448)]
+
+
+def _is_ip_literal(host):
+ """Test if the given host name is either an IPv4 or IPv6 literal.
- :type: int
+ Args:
+ host (bytes)
+
+ Returns:
+ bool
"""
+
+ host = host.decode("ascii")
+
+ try:
+ IPAddress(host)
+ return True
+ except AddrFormatError:
+ return False
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b32188766d..3fe4ffb9e5 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
SERVER_CACHE = {}
-@attr.s
+@attr.s(slots=True, frozen=True)
class Server(object):
"""
Our record of an individual server which can be tried to reach a destination.
@@ -53,34 +53,47 @@ class Server(object):
expires = attr.ib(default=0)
-def pick_server_from_list(server_list):
- """Randomly choose a server from the server list
+def _sort_server_list(server_list):
+ """Given a list of SRV records sort them into priority order and shuffle
+ each priority with the given weight.
+ """
+ priority_map = {}
- Args:
- server_list (list[Server]): list of candidate servers
+ for server in server_list:
+ priority_map.setdefault(server.priority, []).append(server)
- Returns:
- Tuple[bytes, int]: (host, port) pair for the chosen server
- """
- if not server_list:
- raise RuntimeError("pick_server_from_list called with empty list")
+ results = []
+ for priority in sorted(priority_map):
+ servers = priority_map[priority]
+
+ # This algorithms roughly follows the algorithm described in RFC2782,
+ # changed to remove an off-by-one error.
+ #
+ # N.B. Weights can be zero, which means that they should be picked
+ # rarely.
+
+ total_weight = sum(s.weight for s in servers)
+
+ # Total weight can become zero if there are only zero weight servers
+ # left, which we handle by just shuffling and appending to the results.
+ while servers and total_weight:
+ target_weight = random.randint(1, total_weight)
- # TODO: currently we only use the lowest-priority servers. We should maintain a
- # cache of servers known to be "down" and filter them out
+ for s in servers:
+ target_weight -= s.weight
- min_priority = min(s.priority for s in server_list)
- eligible_servers = list(s for s in server_list if s.priority == min_priority)
- total_weight = sum(s.weight for s in eligible_servers)
- target_weight = random.randint(0, total_weight)
+ if target_weight <= 0:
+ break
- for s in eligible_servers:
- target_weight -= s.weight
+ results.append(s)
+ servers.remove(s)
+ total_weight -= s.weight
- if target_weight <= 0:
- return s.host, s.port
+ if servers:
+ random.shuffle(servers)
+ results.extend(servers)
- # this should be impossible.
- raise RuntimeError("pick_server_from_list got to end of eligible server list.")
+ return results
class SrvResolver(object):
@@ -120,7 +133,7 @@ class SrvResolver(object):
if cache_entry:
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
- return servers
+ return _sort_server_list(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
@@ -169,4 +182,4 @@ class SrvResolver(object):
)
self._cache[service_name] = list(servers)
- return servers
+ return _sort_server_list(servers)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 031a316693..55580bc59e 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,9 @@
# limitations under the License.
import logging
-from io import BytesIO
+
+from canonicaljson import encode_canonical_json, json
+from signedjson.sign import sign_json
from twisted.internet import defer
@@ -95,6 +97,7 @@ class RemoteKey(DirectServeResource):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+ self.config = hs.config
@wrap_json_request_handler
async def _async_render_GET(self, request):
@@ -214,15 +217,14 @@ class RemoteKey(DirectServeResource):
yield self.fetcher.get_keys(cache_misses)
yield self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
- result_io = BytesIO()
- result_io.write(b'{"server_keys":')
- sep = b"["
- for json_bytes in json_results:
- result_io.write(sep)
- result_io.write(json_bytes)
- sep = b","
- if sep == b"[":
- result_io.write(sep)
- result_io.write(b"]}")
-
- respond_with_json_bytes(request, 200, result_io.getvalue())
+ signed_keys = []
+ for key_json in json_results:
+ key_json = json.loads(key_json)
+ for signing_key in self.config.key_server_signing_keys:
+ key_json = sign_json(key_json, self.config.server_name, signing_key)
+
+ signed_keys.append(key_json)
+
+ results = {"server_keys": signed_keys}
+
+ respond_with_json_bytes(request, 200, encode_canonical_json(results))
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 5e8fda4b65..20177b44e7 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -34,7 +34,7 @@ class WellKnownBuilder(object):
self._config = hs.config
def get_well_known(self):
- # if we don't have a public_base_url, we can't help much here.
+ # if we don't have a public_baseurl, we can't help much here.
if self._config.public_baseurl is None:
return None
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index d20eacda59..e96eed8a6d 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -238,6 +238,13 @@ def _upgrade_existing_database(
logger.debug("applied_delta_files: %s", applied_delta_files)
+ if isinstance(database_engine, PostgresEngine):
+ specific_engine_extension = ".postgres"
+ else:
+ specific_engine_extension = ".sqlite"
+
+ specific_engine_extensions = (".sqlite", ".postgres")
+
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.info("Upgrading schema to v%d", v)
@@ -274,15 +281,22 @@ def _upgrade_existing_database(
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
- pass
+ continue
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.info("Applying schema %s", relative_path)
executescript(cur, absolute_path)
+ elif ext == specific_engine_extension and root_name.endswith(".sql"):
+ # A .sql file specific to our engine; just read and execute it
+ logger.info("Applying engine-specific schema %s", relative_path)
+ executescript(cur, absolute_path)
+ elif ext in specific_engine_extensions and root_name.endswith(".sql"):
+ # A .sql file for a different engine; skip it.
+ continue
else:
# Not a valid delta file.
- logger.warn(
- "Found directory entry that did not end in .py or" " .sql: %s",
+ logger.warning(
+ "Found directory entry that did not end in .py or .sql: %s",
relative_path,
)
continue
@@ -290,7 +304,7 @@ def _upgrade_existing_database(
# Mark as done.
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO applied_schema_deltas (version, file)" " VALUES (?,?)"
+ "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
),
(v, relative_path),
)
@@ -298,7 +312,7 @@ def _upgrade_existing_database(
cur.execute("DELETE FROM schema_version")
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
),
(v, True),
)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index af15f4cc5a..b08be451aa 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -20,7 +20,6 @@ from synapse.federation.federation_server import server_matches_acl_event
from tests import unittest
-@unittest.DEBUG
class ServerACLsTestCase(unittest.TestCase):
def test_blacklisted_server(self):
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index c55aad8e11..71d7025264 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -41,9 +41,9 @@ from synapse.http.federation.well_known_resolver import (
from synapse.logging.context import LoggingContext
from synapse.util.caches.ttlcache import TTLCache
+from tests import unittest
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.server import FakeTransport, ThreadedMemoryReactorClock
-from tests.unittest import TestCase
from tests.utils import default_config
logger = logging.getLogger(__name__)
@@ -67,7 +67,7 @@ def get_connection_factory():
return test_server_connection_factory
-class MatrixFederationAgentTests(TestCase):
+class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
@@ -1069,8 +1069,64 @@ class MatrixFederationAgentTests(TestCase):
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, None)
+ def test_srv_fallbacks(self):
+ """Test that other SRV results are tried if the first one fails.
+ """
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: [
+ Server(host=b"target.com", port=8443),
+ Server(host=b"target.com", port=8444),
+ ]
+ self.reactor.lookups["target.com"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.testserv"
+ )
+
+ # We should see an attempt to connect to the first server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # Fonx the connection
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # There's a 300ms delay in HostnameEndpoint
+ self.reactor.pump((0.4,))
+
+ # Hasn't failed yet
+ self.assertNoResult(test_d)
+
+ # We shouldnow see an attempt to connect to the second server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8444)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(client_factory, expected_sni=b"testserv")
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
-class TestCachePeriodFromHeaders(TestCase):
+class TestCachePeriodFromHeaders(unittest.TestCase):
def test_cache_control(self):
# uppercase
self.assertEqual(
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 3b885ef64b..df034ab237 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -83,8 +83,10 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
- entry = Mock(spec_set=["expires"])
+ entry = Mock(spec_set=["expires", "priority", "weight"])
entry.expires = 0
+ entry.priority = 0
+ entry.weight = 0
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@@ -105,8 +107,10 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
- entry = Mock(spec_set=["expires"])
+ entry = Mock(spec_set=["expires", "priority", "weight"])
entry.expires = 999999999
+ entry.priority = 0
+ entry.weight = 0
cache = {service_name: [entry]}
resolver = SrvResolver(
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index e0605dac2f..18f1a0035d 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -74,7 +74,6 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertEqual(filtered[i].content["a"], "b")
- @tests.unittest.DEBUG
@defer.inlineCallbacks
def test_erased_user(self):
# 4 message events, from erased and unerased users, with a membership
|