diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/app/__init__.py | 6 | ||||
-rwxr-xr-x | synapse/app/homeserver.py | 3 | ||||
-rw-r--r-- | synapse/config/_base.py | 56 | ||||
-rw-r--r-- | synapse/config/database.py | 5 | ||||
-rw-r--r-- | synapse/config/homeserver.py | 7 | ||||
-rw-r--r-- | synapse/config/key.py | 27 | ||||
-rw-r--r-- | synapse/config/logger.py | 4 | ||||
-rw-r--r-- | synapse/config/metrics.py | 12 | ||||
-rw-r--r-- | synapse/config/registration.py | 11 | ||||
-rw-r--r-- | synapse/config/repository.py | 8 | ||||
-rw-r--r-- | synapse/config/server.py | 75 | ||||
-rw-r--r-- | synapse/http/client.py | 377 | ||||
-rw-r--r-- | synapse/http/endpoint.py | 35 | ||||
-rw-r--r-- | synapse/python_dependencies.py | 228 | ||||
-rw-r--r-- | synapse/rest/media/v1/preview_url_resource.py | 14 |
15 files changed, 469 insertions, 399 deletions
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index c3afcc573b..233bf43fc8 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -22,11 +22,11 @@ sys.dont_write_bytecode = True try: python_dependencies.check_requirements() -except python_dependencies.MissingRequirementError as e: +except python_dependencies.DependencyException as e: message = "\n".join([ - "Missing Requirement: %s" % (str(e),), + "Missing Requirements: %s" % (", ".join(e.dependencies),), "To install run:", - " pip install --upgrade --force \"%s\"" % (e.dependency,), + " pip install --upgrade --force %s" % (" ".join(e.dependencies),), "", ]) sys.stderr.writelines(message) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index f2064f9d0c..f3ac3d19f0 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -322,9 +322,6 @@ def setup(config_options): synapse.config.logger.setup_logging(config, use_worker_options=False) - # check any extra requirements we have now we have a config - check_requirements(config) - events.USE_FROZEN_DICTS = config.use_frozen_dicts tls_server_context_factory = context_factory.ServerContextFactory(config) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 14dae65ea0..fd2d6d52ef 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -135,10 +135,6 @@ class Config(object): return file_stream.read() @staticmethod - def default_path(name): - return os.path.abspath(os.path.join(os.path.curdir, name)) - - @staticmethod def read_config_file(file_path): with open(file_path) as file_stream: return yaml.load(file_stream) @@ -151,8 +147,39 @@ class Config(object): return results def generate_config( - self, config_dir_path, server_name, is_generating_file, report_stats=None + self, + config_dir_path, + data_dir_path, + server_name, + generate_secrets=False, + report_stats=None, ): + """Build a default configuration file + + This is used both when the user explicitly asks us to generate a config file + (eg with --generate_config), and before loading the config at runtime (to give + a base which the config files override) + + Args: + config_dir_path (str): The path where the config files are kept. Used to + create filenames for things like the log config and the signing key. + + data_dir_path (str): The path where the data files are kept. Used to create + filenames for things like the database and media store. + + server_name (str): The server name. Used to initialise the server_name + config param, but also used in the names of some of the config files. + + generate_secrets (bool): True if we should generate new secrets for things + like the macaroon_secret_key. If False, these parameters will be left + unset. + + report_stats (bool|None): Initial setting for the report_stats setting. + If None, report_stats will be left unset. + + Returns: + str: the yaml config file + """ default_config = "# vim:ft=yaml\n" default_config += "\n\n".join( @@ -160,15 +187,14 @@ class Config(object): for conf in self.invoke_all( "default_config", config_dir_path=config_dir_path, + data_dir_path=data_dir_path, server_name=server_name, - is_generating_file=is_generating_file, + generate_secrets=generate_secrets, report_stats=report_stats, ) ) - config = yaml.load(default_config) - - return default_config, config + return default_config @classmethod def load_config(cls, description, argv): @@ -274,12 +300,14 @@ class Config(object): if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "w") as config_file: - config_str, config = obj.generate_config( + config_str = obj.generate_config( config_dir_path=config_dir_path, + data_dir_path=os.getcwd(), server_name=server_name, report_stats=(config_args.report_stats == "yes"), - is_generating_file=True, + generate_secrets=True, ) + config = yaml.load(config_str) obj.invoke_all("generate_files", config) config_file.write(config_str) print( @@ -350,11 +378,13 @@ class Config(object): raise ConfigError(MISSING_SERVER_NAME) server_name = specified_config["server_name"] - _, config = self.generate_config( + config_string = self.generate_config( config_dir_path=config_dir_path, + data_dir_path=os.getcwd(), server_name=server_name, - is_generating_file=False, + generate_secrets=False, ) + config = yaml.load(config_string) config.pop("log_config") config.update(specified_config) diff --git a/synapse/config/database.py b/synapse/config/database.py index e915d9d09b..c8890147a6 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from ._base import Config @@ -45,8 +46,8 @@ class DatabaseConfig(Config): self.set_databasepath(config.get("database_path")) - def default_config(self, **kwargs): - database_path = self.abspath("homeserver.db") + def default_config(self, data_dir_path, **kwargs): + database_path = os.path.join(data_dir_path, "homeserver.db") return """\ # Database configuration database: diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 9d740c7a71..5aad062c36 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -53,10 +53,3 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, ServerNoticesConfig, RoomDirectoryConfig, ): pass - - -if __name__ == '__main__': - import sys - sys.stdout.write( - HomeServerConfig().generate_config(sys.argv[1], sys.argv[2], True)[0] - ) diff --git a/synapse/config/key.py b/synapse/config/key.py index 279c47bb48..53f48fe2dd 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -66,26 +66,35 @@ class KeyConfig(Config): # falsification of values self.form_secret = config.get("form_secret", None) - def default_config(self, config_dir_path, server_name, is_generating_file=False, + def default_config(self, config_dir_path, server_name, generate_secrets=False, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) - if is_generating_file: - macaroon_secret_key = random_string_with_symbols(50) - form_secret = '"%s"' % random_string_with_symbols(50) + if generate_secrets: + macaroon_secret_key = 'macaroon_secret_key: "%s"' % ( + random_string_with_symbols(50), + ) + form_secret = 'form_secret: "%s"' % random_string_with_symbols(50) else: - macaroon_secret_key = None - form_secret = 'null' + macaroon_secret_key = "# macaroon_secret_key: <PRIVATE STRING>" + form_secret = "# form_secret: <PRIVATE STRING>" return """\ - macaroon_secret_key: "%(macaroon_secret_key)s" + # a secret which is used to sign access tokens. If none is specified, + # the registration_shared_secret is used, if one is given; otherwise, + # a secret key is derived from the signing key. + # + # Note that changing this will invalidate any active access tokens, so + # all clients will have to log back in. + %(macaroon_secret_key)s # Used to enable access token expiration. expire_access_token: False # a secret which is used to calculate HMACs for form values, to stop - # falsification of values - form_secret: %(form_secret)s + # falsification of values. Must be specified for the User Consent + # forms to work. + %(form_secret)s ## Signing Keys ## diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 7081868963..f87efecbf8 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -80,9 +80,7 @@ class LoggingConfig(Config): self.log_file = self.abspath(config.get("log_file")) def default_config(self, config_dir_path, server_name, **kwargs): - log_config = self.abspath( - os.path.join(config_dir_path, server_name + ".log.config") - ) + log_config = os.path.join(config_dir_path, server_name + ".log.config") return """ # A yaml python logging config file log_config: "%(log_config)s" diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 61155c99d0..718c43ae03 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -24,10 +24,16 @@ class MetricsConfig(Config): self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1") def default_config(self, report_stats=None, **kwargs): - suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n" - return ("""\ + res = """\ ## Metrics ### # Enable collection and rendering of performance metrics enable_metrics: False - """ + suffix) % locals() + """ + + if report_stats is None: + res += "# report_stats: true|false\n" + else: + res += "report_stats: %s\n" % ('true' if report_stats else 'false') + + return res diff --git a/synapse/config/registration.py b/synapse/config/registration.py index e365f0c30b..6c2b543b8c 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -50,8 +50,13 @@ class RegistrationConfig(Config): raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) - def default_config(self, **kwargs): - registration_shared_secret = random_string_with_symbols(50) + def default_config(self, generate_secrets=False, **kwargs): + if generate_secrets: + registration_shared_secret = 'registration_shared_secret: "%s"' % ( + random_string_with_symbols(50), + ) + else: + registration_shared_secret = '# registration_shared_secret: <PRIVATE STRING>' return """\ ## Registration ## @@ -78,7 +83,7 @@ class RegistrationConfig(Config): # If set, allows registration by anyone who also has the shared # secret, even if registration is otherwise disabled. - registration_shared_secret: "%(registration_shared_secret)s" + %(registration_shared_secret)s # Set the number of bcrypt rounds used to generate password hash. # Larger numbers increase the work factor needed to generate the hash. diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 06c62ab62c..76e3340a91 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from collections import namedtuple from synapse.util.module_loader import load_module @@ -175,9 +175,9 @@ class ContentRepositoryConfig(Config): "url_preview_url_blacklist", () ) - def default_config(self, **kwargs): - media_store = self.default_path("media_store") - uploads_path = self.default_path("uploads") + def default_config(self, data_dir_path, **kwargs): + media_store = os.path.join(data_dir_path, "media_store") + uploads_path = os.path.join(data_dir_path, "uploads") return r""" # Directory where uploaded images and attachments are stored. media_store_path: "%(media_store)s" diff --git a/synapse/config/server.py b/synapse/config/server.py index a9154ad462..120c2b81fc 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +import os.path from synapse.http.endpoint import parse_and_validate_server_name @@ -203,7 +204,7 @@ class ServerConfig(Config): ] }) - def default_config(self, server_name, **kwargs): + def default_config(self, server_name, data_dir_path, **kwargs): _, bind_port = parse_and_validate_server_name(server_name) if bind_port is not None: unsecure_port = bind_port - 400 @@ -211,7 +212,7 @@ class ServerConfig(Config): bind_port = 8448 unsecure_port = 8008 - pid_file = self.abspath("homeserver.pid") + pid_file = os.path.join(data_dir_path, "homeserver.pid") return """\ ## Server ## @@ -356,41 +357,41 @@ class ServerConfig(Config): # type: manhole - # Homeserver blocking - # - # How to reach the server admin, used in ResourceLimitError - # admin_contact: 'mailto:admin@server.com' - # - # Global block config - # - # hs_disabled: False - # hs_disabled_message: 'Human readable reason for why the HS is blocked' - # hs_disabled_limit_type: 'error code(str), to help clients decode reason' - # - # Monthly Active User Blocking - # - # Enables monthly active user checking - # limit_usage_by_mau: False - # max_mau_value: 50 - # mau_trial_days: 2 - # - # If enabled, the metrics for the number of monthly active users will - # be populated, however no one will be limited. If limit_usage_by_mau - # is true, this is implied to be true. - # mau_stats_only: False - # - # Sometimes the server admin will want to ensure certain accounts are - # never blocked by mau checking. These accounts are specified here. - # - # mau_limit_reserved_threepids: - # - medium: 'email' - # address: 'reserved_user@example.com' - # - # Room searching - # - # If disabled, new messages will not be indexed for searching and users - # will receive errors when searching for messages. Defaults to enabled. - # enable_search: true + # Homeserver blocking + # + # How to reach the server admin, used in ResourceLimitError + # admin_contact: 'mailto:admin@server.com' + # + # Global block config + # + # hs_disabled: False + # hs_disabled_message: 'Human readable reason for why the HS is blocked' + # hs_disabled_limit_type: 'error code(str), to help clients decode reason' + # + # Monthly Active User Blocking + # + # Enables monthly active user checking + # limit_usage_by_mau: False + # max_mau_value: 50 + # mau_trial_days: 2 + # + # If enabled, the metrics for the number of monthly active users will + # be populated, however no one will be limited. If limit_usage_by_mau + # is true, this is implied to be true. + # mau_stats_only: False + # + # Sometimes the server admin will want to ensure certain accounts are + # never blocked by mau checking. These accounts are specified here. + # + # mau_limit_reserved_threepids: + # - medium: 'email' + # address: 'reserved_user@example.com' + # + # Room searching + # + # If disabled, new messages will not be indexed for searching and users + # will receive errors when searching for messages. Defaults to enabled. + # enable_search: true """ % locals() def read_arguments(self, args): diff --git a/synapse/http/client.py b/synapse/http/client.py index 3d05f83b8c..afcf698b29 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -21,28 +21,25 @@ from six.moves import urllib import treq from canonicaljson import encode_canonical_json, json +from netaddr import IPAddress from prometheus_client import Counter +from zope.interface import implementer, provider from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE -from twisted.internet import defer, protocol, reactor, ssl -from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.web._newclient import ResponseDone -from twisted.web.client import ( - Agent, - BrowserLikeRedirectAgent, - ContentDecoderAgent, - GzipDecoder, - HTTPConnectionPool, - PartialDownloadError, - readBody, +from twisted.internet import defer, protocol, ssl +from twisted.internet.interfaces import ( + IReactorPluggableNameResolver, + IResolutionReceiver, ) +from twisted.python.failure import Failure +from twisted.web._newclient import ResponseDone +from twisted.web.client import Agent, HTTPConnectionPool, PartialDownloadError, readBody from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri -from synapse.http.endpoint import SpiderEndpoint from synapse.util.async_helpers import timeout_deferred from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable @@ -50,8 +47,125 @@ from synapse.util.logcontext import make_deferred_yieldable logger = logging.getLogger(__name__) outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) -incoming_responses_counter = Counter("synapse_http_client_responses", "", - ["method", "code"]) +incoming_responses_counter = Counter( + "synapse_http_client_responses", "", ["method", "code"] +) + + +def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): + """ + Args: + ip_address (netaddr.IPAddress) + ip_whitelist (netaddr.IPSet) + ip_blacklist (netaddr.IPSet) + """ + if ip_address in ip_blacklist: + if ip_whitelist is None or ip_address not in ip_whitelist: + return True + return False + + +class IPBlacklistingResolver(object): + """ + A proxy for reactor.nameResolver which only produces non-blacklisted IP + addresses, preventing DNS rebinding attacks on URL preview. + """ + + def __init__(self, reactor, ip_whitelist, ip_blacklist): + """ + Args: + reactor (twisted.internet.reactor) + ip_whitelist (netaddr.IPSet) + ip_blacklist (netaddr.IPSet) + """ + self._reactor = reactor + self._ip_whitelist = ip_whitelist + self._ip_blacklist = ip_blacklist + + def resolveHostName(self, recv, hostname, portNumber=0): + + r = recv() + d = defer.Deferred() + addresses = [] + + @provider(IResolutionReceiver) + class EndpointReceiver(object): + @staticmethod + def resolutionBegan(resolutionInProgress): + pass + + @staticmethod + def addressResolved(address): + ip_address = IPAddress(address.host) + + if check_against_blacklist( + ip_address, self._ip_whitelist, self._ip_blacklist + ): + logger.info( + "Dropped %s from DNS resolution to %s" % (ip_address, hostname) + ) + raise SynapseError(403, "IP address blocked by IP blacklist entry") + + addresses.append(address) + + @staticmethod + def resolutionComplete(): + d.callback(addresses) + + self._reactor.nameResolver.resolveHostName( + EndpointReceiver, hostname, portNumber=portNumber + ) + + def _callback(addrs): + r.resolutionBegan(None) + for i in addrs: + r.addressResolved(i) + r.resolutionComplete() + + d.addCallback(_callback) + + return r + + +class BlacklistingAgentWrapper(Agent): + """ + An Agent wrapper which will prevent access to IP addresses being accessed + directly (without an IP address lookup). + """ + + def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None): + """ + Args: + agent (twisted.web.client.Agent): The Agent to wrap. + reactor (twisted.internet.reactor) + ip_whitelist (netaddr.IPSet) + ip_blacklist (netaddr.IPSet) + """ + self._agent = agent + self._ip_whitelist = ip_whitelist + self._ip_blacklist = ip_blacklist + + def request(self, method, uri, headers=None, bodyProducer=None): + h = urllib.parse.urlparse(uri.decode('ascii')) + + try: + ip_address = IPAddress(h.hostname) + + if check_against_blacklist( + ip_address, self._ip_whitelist, self._ip_blacklist + ): + logger.info( + "Blocking access to %s because of blacklist" % (ip_address,) + ) + e = SynapseError(403, "IP address blocked by IP blacklist entry") + return defer.fail(Failure(e)) + except Exception: + # Not an IP + pass + + return self._agent.request( + method, uri, headers=headers, bodyProducer=bodyProducer + ) class SimpleHttpClient(object): @@ -59,14 +173,54 @@ class SimpleHttpClient(object): A simple, no-frills HTTP client with methods that wrap up common ways of using HTTP in Matrix """ - def __init__(self, hs): + + def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None): + """ + Args: + hs (synapse.server.HomeServer) + treq_args (dict): Extra keyword arguments to be given to treq.request. + ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that + we may not request. + ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can + request if it were otherwise caught in a blacklist. + """ self.hs = hs - pool = HTTPConnectionPool(reactor) + self._ip_whitelist = ip_whitelist + self._ip_blacklist = ip_blacklist + self._extra_treq_args = treq_args + + self.user_agent = hs.version_string + self.clock = hs.get_clock() + if hs.config.user_agent_suffix: + self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) + + self.user_agent = self.user_agent.encode('ascii') + + if self._ip_blacklist: + real_reactor = hs.get_reactor() + # If we have an IP blacklist, we need to use a DNS resolver which + # filters out blacklisted IP addresses, to prevent DNS rebinding. + nameResolver = IPBlacklistingResolver( + real_reactor, self._ip_whitelist, self._ip_blacklist + ) + + @implementer(IReactorPluggableNameResolver) + class Reactor(object): + def __getattr__(_self, attr): + if attr == "nameResolver": + return nameResolver + else: + return getattr(real_reactor, attr) + + self.reactor = Reactor() + else: + self.reactor = hs.get_reactor() # the pusher makes lots of concurrent SSL connections to sygnal, and - # tends to do so in batches, so we need to allow the pool to keep lots - # of idle connections around. + # tends to do so in batches, so we need to allow the pool to keep + # lots of idle connections around. + pool = HTTPConnectionPool(self.reactor) pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) pool.cachedConnectionTimeout = 2 * 60 @@ -74,20 +228,35 @@ class SimpleHttpClient(object): # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' self.agent = Agent( - reactor, + self.reactor, connectTimeout=15, - contextFactory=hs.get_http_client_context_factory(), + contextFactory=self.hs.get_http_client_context_factory(), pool=pool, ) - self.user_agent = hs.version_string - self.clock = hs.get_clock() - if hs.config.user_agent_suffix: - self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) - self.user_agent = self.user_agent.encode('ascii') + if self._ip_blacklist: + # If we have an IP blacklist, we then install the blacklisting Agent + # which prevents direct access to IP addresses, that are not caught + # by the DNS resolution. + self.agent = BlacklistingAgentWrapper( + self.agent, + self.reactor, + ip_whitelist=self._ip_whitelist, + ip_blacklist=self._ip_blacklist, + ) @defer.inlineCallbacks def request(self, method, uri, data=b'', headers=None): + """ + Args: + method (str): HTTP method to use. + uri (str): URI to query. + data (bytes): Data to send in the request body, if applicable. + headers (t.w.http_headers.Headers): Request headers. + + Raises: + SynapseError: If the IP is blacklisted. + """ # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.labels(method).inc() @@ -97,25 +266,34 @@ class SimpleHttpClient(object): try: request_deferred = treq.request( - method, uri, agent=self.agent, data=data, headers=headers + method, + uri, + agent=self.agent, + data=data, + headers=headers, + **self._extra_treq_args ) request_deferred = timeout_deferred( - request_deferred, 60, self.hs.get_reactor(), + request_deferred, + 60, + self.hs.get_reactor(), cancelled_to_request_timed_out_error, ) response = yield make_deferred_yieldable(request_deferred) incoming_responses_counter.labels(method, response.code).inc() logger.info( - "Received response to %s %s: %s", - method, redact_uri(uri), response.code + "Received response to %s %s: %s", method, redact_uri(uri), response.code ) defer.returnValue(response) except Exception as e: incoming_responses_counter.labels(method, "ERR").inc() logger.info( "Error sending request to %s %s: %s %s", - method, redact_uri(uri), type(e).__name__, e.args[0] + method, + redact_uri(uri), + type(e).__name__, + e.args[0], ) raise @@ -140,8 +318,9 @@ class SimpleHttpClient(object): # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) - query_bytes = urllib.parse.urlencode( - encode_urlencode_args(args), True).encode("utf8") + query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode( + "utf8" + ) actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], @@ -151,10 +330,7 @@ class SimpleHttpClient(object): actual_headers.update(headers) response = yield self.request( - "POST", - uri, - headers=Headers(actual_headers), - data=query_bytes + "POST", uri, headers=Headers(actual_headers), data=query_bytes ) if 200 <= response.code < 300: @@ -193,10 +369,7 @@ class SimpleHttpClient(object): actual_headers.update(headers) response = yield self.request( - "POST", - uri, - headers=Headers(actual_headers), - data=json_str + "POST", uri, headers=Headers(actual_headers), data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -264,10 +437,7 @@ class SimpleHttpClient(object): actual_headers.update(headers) response = yield self.request( - "PUT", - uri, - headers=Headers(actual_headers), - data=json_str + "PUT", uri, headers=Headers(actual_headers), data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -299,17 +469,11 @@ class SimpleHttpClient(object): query_bytes = urllib.parse.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) - actual_headers = { - b"User-Agent": [self.user_agent], - } + actual_headers = {b"User-Agent": [self.user_agent]} if headers: actual_headers.update(headers) - response = yield self.request( - "GET", - uri, - headers=Headers(actual_headers), - ) + response = yield self.request("GET", uri, headers=Headers(actual_headers)) body = yield make_deferred_yieldable(readBody(response)) @@ -334,22 +498,18 @@ class SimpleHttpClient(object): headers, absolute URI of the response and HTTP response code. """ - actual_headers = { - b"User-Agent": [self.user_agent], - } + actual_headers = {b"User-Agent": [self.user_agent]} if headers: actual_headers.update(headers) - response = yield self.request( - "GET", - url, - headers=Headers(actual_headers), - ) + response = yield self.request("GET", url, headers=Headers(actual_headers)) resp_headers = dict(response.headers.getAllRawHeaders()) - if (b'Content-Length' in resp_headers and - int(resp_headers[b'Content-Length']) > max_size): + if ( + b'Content-Length' in resp_headers + and int(resp_headers[b'Content-Length'][0]) > max_size + ): logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, @@ -359,26 +519,20 @@ class SimpleHttpClient(object): if response.code > 299: logger.warn("Got %d when downloading %s" % (response.code, url)) - raise SynapseError( - 502, - "Got error %d" % (response.code,), - Codes.UNKNOWN, - ) + raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) # TODO: if our Content-Type is HTML or something, just read the first # N bytes into RAM rather than saving it all to disk only to read it # straight back in again try: - length = yield make_deferred_yieldable(_readBodyToFile( - response, output_stream, max_size, - )) + length = yield make_deferred_yieldable( + _readBodyToFile(response, output_stream, max_size) + ) except Exception as e: logger.exception("Failed to download body") raise SynapseError( - 502, - ("Failed to download remote body: %s" % e), - Codes.UNKNOWN, + 502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN ) defer.returnValue( @@ -387,13 +541,14 @@ class SimpleHttpClient(object): resp_headers, response.request.absoluteURI.decode('ascii'), response.code, - ), + ) ) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. + class _ReadBodyToFileProtocol(protocol.Protocol): def __init__(self, stream, deferred, max_size): self.stream = stream @@ -405,11 +560,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback(SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - )) + self.deferred.errback( + SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + ) + ) self.deferred = defer.Deferred() self.transport.loseConnection() @@ -427,6 +584,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol): # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. + def _readBodyToFile(response, stream, max_size): d = defer.Deferred() response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) @@ -449,10 +607,12 @@ class CaptchaServerHttpClient(SimpleHttpClient): "POST", url, data=query_bytes, - headers=Headers({ - b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [self.user_agent], - }) + headers=Headers( + { + b"Content-Type": [b"application/x-www-form-urlencoded"], + b"User-Agent": [self.user_agent], + } + ), ) try: @@ -463,57 +623,6 @@ class CaptchaServerHttpClient(SimpleHttpClient): defer.returnValue(e.response) -class SpiderEndpointFactory(object): - def __init__(self, hs): - self.blacklist = hs.config.url_preview_ip_range_blacklist - self.whitelist = hs.config.url_preview_ip_range_whitelist - self.policyForHTTPS = hs.get_http_client_context_factory() - - def endpointForURI(self, uri): - logger.info("Getting endpoint for %s", uri.toBytes()) - - if uri.scheme == b"http": - endpoint_factory = HostnameEndpoint - elif uri.scheme == b"https": - tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) - - def endpoint_factory(reactor, host, port, **kw): - return wrapClientTLS( - tlsCreator, - HostnameEndpoint(reactor, host, port, **kw)) - else: - logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme) - return None - return SpiderEndpoint( - reactor, uri.host, uri.port, self.blacklist, self.whitelist, - endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15), - ) - - -class SpiderHttpClient(SimpleHttpClient): - """ - Separate HTTP client for spidering arbitrary URLs. - Special in that it follows retries and has a UA that looks - like a browser. - - used by the preview_url endpoint in the content repo. - """ - def __init__(self, hs): - SimpleHttpClient.__init__(self, hs) - # clobber the base class's agent and UA: - self.agent = ContentDecoderAgent( - BrowserLikeRedirectAgent( - Agent.usingEndpointFactory( - reactor, - SpiderEndpointFactory(hs) - ) - ), [(b'gzip', GzipDecoder)] - ) - # We could look like Chrome: - # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) - # Chrome Safari" % hs.version_string) - - def encode_urlencode_args(args): return {k: encode_urlencode_arg(v) for k, v in args.items()} diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 91025037a3..f86a0b624e 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -218,41 +218,6 @@ class _WrappedConnection(object): return d -class SpiderEndpoint(object): - """An endpoint which refuses to connect to blacklisted IP addresses - Implements twisted.internet.interfaces.IStreamClientEndpoint. - """ - def __init__(self, reactor, host, port, blacklist, whitelist, - endpoint=HostnameEndpoint, endpoint_kw_args={}): - self.reactor = reactor - self.host = host - self.port = port - self.blacklist = blacklist - self.whitelist = whitelist - self.endpoint = endpoint - self.endpoint_kw_args = endpoint_kw_args - - @defer.inlineCallbacks - def connect(self, protocolFactory): - address = yield self.reactor.resolve(self.host) - - from netaddr import IPAddress - ip_address = IPAddress(address) - - if ip_address in self.blacklist: - if self.whitelist is None or ip_address not in self.whitelist: - raise ConnectError( - "Refusing to spider blacklisted IP address %s" % address - ) - - logger.info("Connecting to %s:%s", address, self.port) - endpoint = self.endpoint( - self.reactor, address, self.port, **self.endpoint_kw_args - ) - connection = yield endpoint.connect(protocolFactory) - defer.returnValue(connection) - - class SRVClientEndpoint(object): """An endpoint which looks up SRV records for a service. Cycles through the list of servers starting with each call to connect diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 72a92cc462..2c65ef5856 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -15,175 +15,121 @@ # limitations under the License. import logging -from distutils.version import LooseVersion + +from pkg_resources import DistributionNotFound, VersionConflict, get_distribution logger = logging.getLogger(__name__) -# this dict maps from python package name to a list of modules we expect it to -# provide. -# -# the key is a "requirement specifier", as used as a parameter to `pip -# install`[1], or an `install_requires` argument to `setuptools.setup` [2]. + +# REQUIREMENTS is a simple list of requirement specifiers[1], and must be +# installed. It is passed to setup() as install_requires in setup.py. # -# the value is a sequence of strings; each entry should be the name of the -# python module, optionally followed by a version assertion which can be either -# ">=<ver>" or "==<ver>". +# CONDITIONAL_REQUIREMENTS is the optional dependencies, represented as a dict +# of lists. The dict key is the optional dependency name and can be passed to +# pip when installing. The list is a series of requirement specifiers[1] to be +# installed when that optional dependency requirement is specified. It is passed +# to setup() as extras_require in setup.py # # [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. -# [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies -REQUIREMENTS = { - "jsonschema>=2.5.1": ["jsonschema>=2.5.1"], - "frozendict>=1": ["frozendict"], - "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], - "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"], - "signedjson>=1.0.0": ["signedjson>=1.0.0"], - "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"], - "service_identity>=16.0.0": ["service_identity>=16.0.0"], - "Twisted>=17.1.0": ["twisted>=17.1.0"], - "treq>=15.1": ["treq>=15.1"], +REQUIREMENTS = [ + "jsonschema>=2.5.1", + "frozendict>=1", + "unpaddedbase64>=1.1.0", + "canonicaljson>=1.1.3", + "signedjson>=1.0.0", + "pynacl>=1.2.1", + "service_identity>=16.0.0", + "Twisted>=17.1.0", + "treq>=15.1", # Twisted has required pyopenssl 16.0 since about Twisted 16.6. - "pyopenssl>=16.0.0": ["OpenSSL>=16.0.0"], - - "pyyaml>=3.11": ["yaml"], - "pyasn1>=0.1.9": ["pyasn1"], - "pyasn1-modules>=0.0.7": ["pyasn1_modules"], - "daemonize>=2.3.1": ["daemonize"], - "bcrypt>=3.1.0": ["bcrypt>=3.1.0"], - "pillow>=3.1.2": ["PIL"], - "sortedcontainers>=1.4.4": ["sortedcontainers"], - "psutil>=2.0.0": ["psutil>=2.0.0"], - "pymacaroons-pynacl>=0.9.3": ["pymacaroons"], - "msgpack-python>=0.4.2": ["msgpack"], - "phonenumbers>=8.2.0": ["phonenumbers"], - "six>=1.10": ["six"], - + "pyopenssl>=16.0.0", + "pyyaml>=3.11", + "pyasn1>=0.1.9", + "pyasn1-modules>=0.0.7", + "daemonize>=2.3.1", + "bcrypt>=3.1.0", + "pillow>=3.1.2", + "sortedcontainers>=1.4.4", + "psutil>=2.0.0", + "pymacaroons-pynacl>=0.9.3", + "msgpack-python>=0.4.2", + "phonenumbers>=8.2.0", + "six>=1.10", # prometheus_client 0.4.0 changed the format of counter metrics # (cf https://github.com/matrix-org/synapse/issues/4001) - "prometheus_client>=0.0.18,<0.4.0": ["prometheus_client"], - + "prometheus_client>=0.0.18,<0.4.0", # we use attr.s(slots), which arrived in 16.0.0 - "attrs>=16.0.0": ["attr>=16.0.0"], - "netaddr>=0.7.18": ["netaddr"], -} + "attrs>=16.0.0", + "netaddr>=0.7.18", +] CONDITIONAL_REQUIREMENTS = { - "email.enable_notifs": { - "Jinja2>=2.8": ["Jinja2>=2.8"], - "bleach>=1.4.2": ["bleach>=1.4.2"], - }, - "matrix-synapse-ldap3": { - "matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"], - }, - "postgres": { - "psycopg2>=2.6": ["psycopg2"] - }, - "saml2": { - "pysaml2>=4.5.0": ["saml2"], - }, + "email.enable_notifs": ["Jinja2>=2.8", "bleach>=1.4.2"], + "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], + "postgres": ["psycopg2>=2.6"], + "saml2": ["pysaml2>=4.5.0"], + "url_preview": ["lxml>=3.5.0"], + "test": ["mock>=2.0"], } -def requirements(config=None, include_conditional=False): - reqs = REQUIREMENTS.copy() - if include_conditional: - for _, req in CONDITIONAL_REQUIREMENTS.items(): - reqs.update(req) - return reqs +def list_requirements(): + deps = set(REQUIREMENTS) + for opt in CONDITIONAL_REQUIREMENTS.values(): + deps = set(opt) | deps + return list(deps) -def github_link(project, version, egg): - return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg) +class DependencyException(Exception): + @property + def dependencies(self): + for i in self.args[0]: + yield '"' + i + '"' -DEPENDENCY_LINKS = { -} +def check_requirements(_get_distribution=get_distribution): + + deps_needed = [] + errors = [] -class MissingRequirementError(Exception): - def __init__(self, message, module_name, dependency): - super(MissingRequirementError, self).__init__(message) - self.module_name = module_name - self.dependency = dependency - - -def check_requirements(config=None): - """Checks that all the modules needed by synapse have been correctly - installed and are at the correct version""" - for dependency, module_requirements in ( - requirements(config, include_conditional=False).items()): - for module_requirement in module_requirements: - if ">=" in module_requirement: - module_name, required_version = module_requirement.split(">=") - version_test = ">=" - elif "==" in module_requirement: - module_name, required_version = module_requirement.split("==") - version_test = "==" - else: - module_name = module_requirement - version_test = None - - try: - module = __import__(module_name) - except ImportError: - logging.exception( - "Can't import %r which is part of %r", - module_name, dependency - ) - raise MissingRequirementError( - "Can't import %r which is part of %r" - % (module_name, dependency), module_name, dependency - ) - version = getattr(module, "__version__", None) - file_path = getattr(module, "__file__", None) - logger.info( - "Using %r version %r from %r to satisfy %r", - module_name, version, file_path, dependency + # Check the base dependencies exist -- they all must be installed. + for dependency in REQUIREMENTS: + try: + _get_distribution(dependency) + except VersionConflict as e: + deps_needed.append(dependency) + errors.append( + "Needed %s, got %s==%s" + % (dependency, e.dist.project_name, e.dist.version) ) + except DistributionNotFound: + deps_needed.append(dependency) + errors.append("Needed %s but it was not installed" % (dependency,)) - if version_test == ">=": - if version is None: - raise MissingRequirementError( - "Version of %r isn't set as __version__ of module %r" - % (dependency, module_name), module_name, dependency - ) - if LooseVersion(version) < LooseVersion(required_version): - raise MissingRequirementError( - "Version of %r in %r is too old. %r < %r" - % (dependency, file_path, version, required_version), - module_name, dependency - ) - elif version_test == "==": - if version is None: - raise MissingRequirementError( - "Version of %r isn't set as __version__ of module %r" - % (dependency, module_name), module_name, dependency - ) - if LooseVersion(version) != LooseVersion(required_version): - raise MissingRequirementError( - "Unexpected version of %r in %r. %r != %r" - % (dependency, file_path, version, required_version), - module_name, dependency - ) + # Check the optional dependencies are up to date. We allow them to not be + # installed. + OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) + for dependency in OPTS: + try: + _get_distribution(dependency) + except VersionConflict: + deps_needed.append(dependency) + errors.append("Needed %s but it was not installed" % (dependency,)) + except DistributionNotFound: + # If it's not found, we don't care + pass -def list_requirements(): - result = [] - linked = [] - for link in DEPENDENCY_LINKS.values(): - egg = link.split("#egg=")[1] - linked.append(egg.split('-')[0]) - result.append(link) - for requirement in requirements(include_conditional=True): - is_linked = False - for link in linked: - if requirement.replace('-', '_').startswith(link): - is_linked = True - if not is_linked: - result.append(requirement) - return result + if deps_needed: + for e in errors: + logging.exception(e) + + raise DependencyException(deps_needed) if __name__ == "__main__": import sys + sys.stdout.writelines(req + "\n" for req in list_requirements()) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index d0ecf241b6..ba3ab1d37d 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -35,7 +35,7 @@ from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError -from synapse.http.client import SpiderHttpClient +from synapse.http.client import SimpleHttpClient from synapse.http.server import ( respond_with_json, respond_with_json_bytes, @@ -69,7 +69,12 @@ class PreviewUrlResource(Resource): self.max_spider_size = hs.config.max_spider_size self.server_name = hs.hostname self.store = hs.get_datastore() - self.client = SpiderHttpClient(hs) + self.client = SimpleHttpClient( + hs, + treq_args={"browser_like_redirects": True}, + ip_whitelist=hs.config.url_preview_ip_range_whitelist, + ip_blacklist=hs.config.url_preview_ip_range_blacklist, + ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path self.media_storage = media_storage @@ -318,6 +323,11 @@ class PreviewUrlResource(Resource): length, headers, uri, code = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) + except SynapseError: + # Pass SynapseErrors through directly, so that the servlet + # handler will return a SynapseError to the client instead of + # blank data or a 500. + raise except Exception as e: # FIXME: pass through 404s and other error messages nicely logger.warn("Error downloading %s: %r", url, e) |