From a51288e5d6878da8ac889c8a503a6791e5b2ac43 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 15 Nov 2018 10:50:08 -0600 Subject: Add a coveragerc (#4180) --- .coveragerc | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 .coveragerc (limited to '.coveragerc') diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000000..ca333961f3 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,12 @@ +[run] +branch = True +parallel = True +source = synapse + +[paths] +source= + coverage + +[report] +precision = 2 +ignore_errors = True -- cgit 1.4.1 From ea6abf6724f4572a709047034fe2a672641068b9 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 22 Dec 2018 01:56:13 +1100 Subject: Fix IP URL previews on Python 3 (#4215) --- .coveragerc | 1 - .gitignore | 2 - changelog.d/4215.misc | 1 + synapse/http/client.py | 377 +++++++++++++++-------- synapse/http/endpoint.py | 35 --- synapse/rest/media/v1/preview_url_resource.py | 14 +- tests/rest/media/v1/test_url_preview.py | 424 ++++++++++++++++++++------ tests/server.py | 8 + 8 files changed, 590 insertions(+), 272 deletions(-) create mode 100644 changelog.d/4215.misc (limited to '.coveragerc') diff --git a/.coveragerc b/.coveragerc index ca333961f3..9873a30738 100644 --- a/.coveragerc +++ b/.coveragerc @@ -9,4 +9,3 @@ source= [report] precision = 2 -ignore_errors = True diff --git a/.gitignore b/.gitignore index 1b632646bb..d739595c3a 100644 --- a/.gitignore +++ b/.gitignore @@ -59,7 +59,5 @@ env/ .vscode/ .ropeproject/ - *.deb - /debs diff --git a/changelog.d/4215.misc b/changelog.d/4215.misc new file mode 100644 index 0000000000..bb90594836 --- /dev/null +++ b/changelog.d/4215.misc @@ -0,0 +1 @@ +Getting URL previews of IP addresses no longer fails on Python 3. diff --git a/synapse/http/client.py b/synapse/http/client.py index 3d05f83b8c..afcf698b29 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -21,28 +21,25 @@ from six.moves import urllib import treq from canonicaljson import encode_canonical_json, json +from netaddr import IPAddress from prometheus_client import Counter +from zope.interface import implementer, provider from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE -from twisted.internet import defer, protocol, reactor, ssl -from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.web._newclient import ResponseDone -from twisted.web.client import ( - Agent, - BrowserLikeRedirectAgent, - ContentDecoderAgent, - GzipDecoder, - HTTPConnectionPool, - PartialDownloadError, - readBody, +from twisted.internet import defer, protocol, ssl +from twisted.internet.interfaces import ( + IReactorPluggableNameResolver, + IResolutionReceiver, ) +from twisted.python.failure import Failure +from twisted.web._newclient import ResponseDone +from twisted.web.client import Agent, HTTPConnectionPool, PartialDownloadError, readBody from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri -from synapse.http.endpoint import SpiderEndpoint from synapse.util.async_helpers import timeout_deferred from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable @@ -50,8 +47,125 @@ from synapse.util.logcontext import make_deferred_yieldable logger = logging.getLogger(__name__) outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) -incoming_responses_counter = Counter("synapse_http_client_responses", "", - ["method", "code"]) +incoming_responses_counter = Counter( + "synapse_http_client_responses", "", ["method", "code"] +) + + +def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): + """ + Args: + ip_address (netaddr.IPAddress) + ip_whitelist (netaddr.IPSet) + ip_blacklist (netaddr.IPSet) + """ + if ip_address in ip_blacklist: + if ip_whitelist is None or ip_address not in ip_whitelist: + return True + return False + + +class IPBlacklistingResolver(object): + """ + A proxy for reactor.nameResolver which only produces non-blacklisted IP + addresses, preventing DNS rebinding attacks on URL preview. + """ + + def __init__(self, reactor, ip_whitelist, ip_blacklist): + """ + Args: + reactor (twisted.internet.reactor) + ip_whitelist (netaddr.IPSet) + ip_blacklist (netaddr.IPSet) + """ + self._reactor = reactor + self._ip_whitelist = ip_whitelist + self._ip_blacklist = ip_blacklist + + def resolveHostName(self, recv, hostname, portNumber=0): + + r = recv() + d = defer.Deferred() + addresses = [] + + @provider(IResolutionReceiver) + class EndpointReceiver(object): + @staticmethod + def resolutionBegan(resolutionInProgress): + pass + + @staticmethod + def addressResolved(address): + ip_address = IPAddress(address.host) + + if check_against_blacklist( + ip_address, self._ip_whitelist, self._ip_blacklist + ): + logger.info( + "Dropped %s from DNS resolution to %s" % (ip_address, hostname) + ) + raise SynapseError(403, "IP address blocked by IP blacklist entry") + + addresses.append(address) + + @staticmethod + def resolutionComplete(): + d.callback(addresses) + + self._reactor.nameResolver.resolveHostName( + EndpointReceiver, hostname, portNumber=portNumber + ) + + def _callback(addrs): + r.resolutionBegan(None) + for i in addrs: + r.addressResolved(i) + r.resolutionComplete() + + d.addCallback(_callback) + + return r + + +class BlacklistingAgentWrapper(Agent): + """ + An Agent wrapper which will prevent access to IP addresses being accessed + directly (without an IP address lookup). + """ + + def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None): + """ + Args: + agent (twisted.web.client.Agent): The Agent to wrap. + reactor (twisted.internet.reactor) + ip_whitelist (netaddr.IPSet) + ip_blacklist (netaddr.IPSet) + """ + self._agent = agent + self._ip_whitelist = ip_whitelist + self._ip_blacklist = ip_blacklist + + def request(self, method, uri, headers=None, bodyProducer=None): + h = urllib.parse.urlparse(uri.decode('ascii')) + + try: + ip_address = IPAddress(h.hostname) + + if check_against_blacklist( + ip_address, self._ip_whitelist, self._ip_blacklist + ): + logger.info( + "Blocking access to %s because of blacklist" % (ip_address,) + ) + e = SynapseError(403, "IP address blocked by IP blacklist entry") + return defer.fail(Failure(e)) + except Exception: + # Not an IP + pass + + return self._agent.request( + method, uri, headers=headers, bodyProducer=bodyProducer + ) class SimpleHttpClient(object): @@ -59,14 +173,54 @@ class SimpleHttpClient(object): A simple, no-frills HTTP client with methods that wrap up common ways of using HTTP in Matrix """ - def __init__(self, hs): + + def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None): + """ + Args: + hs (synapse.server.HomeServer) + treq_args (dict): Extra keyword arguments to be given to treq.request. + ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that + we may not request. + ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can + request if it were otherwise caught in a blacklist. + """ self.hs = hs - pool = HTTPConnectionPool(reactor) + self._ip_whitelist = ip_whitelist + self._ip_blacklist = ip_blacklist + self._extra_treq_args = treq_args + + self.user_agent = hs.version_string + self.clock = hs.get_clock() + if hs.config.user_agent_suffix: + self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) + + self.user_agent = self.user_agent.encode('ascii') + + if self._ip_blacklist: + real_reactor = hs.get_reactor() + # If we have an IP blacklist, we need to use a DNS resolver which + # filters out blacklisted IP addresses, to prevent DNS rebinding. + nameResolver = IPBlacklistingResolver( + real_reactor, self._ip_whitelist, self._ip_blacklist + ) + + @implementer(IReactorPluggableNameResolver) + class Reactor(object): + def __getattr__(_self, attr): + if attr == "nameResolver": + return nameResolver + else: + return getattr(real_reactor, attr) + + self.reactor = Reactor() + else: + self.reactor = hs.get_reactor() # the pusher makes lots of concurrent SSL connections to sygnal, and - # tends to do so in batches, so we need to allow the pool to keep lots - # of idle connections around. + # tends to do so in batches, so we need to allow the pool to keep + # lots of idle connections around. + pool = HTTPConnectionPool(self.reactor) pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) pool.cachedConnectionTimeout = 2 * 60 @@ -74,20 +228,35 @@ class SimpleHttpClient(object): # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' self.agent = Agent( - reactor, + self.reactor, connectTimeout=15, - contextFactory=hs.get_http_client_context_factory(), + contextFactory=self.hs.get_http_client_context_factory(), pool=pool, ) - self.user_agent = hs.version_string - self.clock = hs.get_clock() - if hs.config.user_agent_suffix: - self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) - self.user_agent = self.user_agent.encode('ascii') + if self._ip_blacklist: + # If we have an IP blacklist, we then install the blacklisting Agent + # which prevents direct access to IP addresses, that are not caught + # by the DNS resolution. + self.agent = BlacklistingAgentWrapper( + self.agent, + self.reactor, + ip_whitelist=self._ip_whitelist, + ip_blacklist=self._ip_blacklist, + ) @defer.inlineCallbacks def request(self, method, uri, data=b'', headers=None): + """ + Args: + method (str): HTTP method to use. + uri (str): URI to query. + data (bytes): Data to send in the request body, if applicable. + headers (t.w.http_headers.Headers): Request headers. + + Raises: + SynapseError: If the IP is blacklisted. + """ # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.labels(method).inc() @@ -97,25 +266,34 @@ class SimpleHttpClient(object): try: request_deferred = treq.request( - method, uri, agent=self.agent, data=data, headers=headers + method, + uri, + agent=self.agent, + data=data, + headers=headers, + **self._extra_treq_args ) request_deferred = timeout_deferred( - request_deferred, 60, self.hs.get_reactor(), + request_deferred, + 60, + self.hs.get_reactor(), cancelled_to_request_timed_out_error, ) response = yield make_deferred_yieldable(request_deferred) incoming_responses_counter.labels(method, response.code).inc() logger.info( - "Received response to %s %s: %s", - method, redact_uri(uri), response.code + "Received response to %s %s: %s", method, redact_uri(uri), response.code ) defer.returnValue(response) except Exception as e: incoming_responses_counter.labels(method, "ERR").inc() logger.info( "Error sending request to %s %s: %s %s", - method, redact_uri(uri), type(e).__name__, e.args[0] + method, + redact_uri(uri), + type(e).__name__, + e.args[0], ) raise @@ -140,8 +318,9 @@ class SimpleHttpClient(object): # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) - query_bytes = urllib.parse.urlencode( - encode_urlencode_args(args), True).encode("utf8") + query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode( + "utf8" + ) actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], @@ -151,10 +330,7 @@ class SimpleHttpClient(object): actual_headers.update(headers) response = yield self.request( - "POST", - uri, - headers=Headers(actual_headers), - data=query_bytes + "POST", uri, headers=Headers(actual_headers), data=query_bytes ) if 200 <= response.code < 300: @@ -193,10 +369,7 @@ class SimpleHttpClient(object): actual_headers.update(headers) response = yield self.request( - "POST", - uri, - headers=Headers(actual_headers), - data=json_str + "POST", uri, headers=Headers(actual_headers), data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -264,10 +437,7 @@ class SimpleHttpClient(object): actual_headers.update(headers) response = yield self.request( - "PUT", - uri, - headers=Headers(actual_headers), - data=json_str + "PUT", uri, headers=Headers(actual_headers), data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -299,17 +469,11 @@ class SimpleHttpClient(object): query_bytes = urllib.parse.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) - actual_headers = { - b"User-Agent": [self.user_agent], - } + actual_headers = {b"User-Agent": [self.user_agent]} if headers: actual_headers.update(headers) - response = yield self.request( - "GET", - uri, - headers=Headers(actual_headers), - ) + response = yield self.request("GET", uri, headers=Headers(actual_headers)) body = yield make_deferred_yieldable(readBody(response)) @@ -334,22 +498,18 @@ class SimpleHttpClient(object): headers, absolute URI of the response and HTTP response code. """ - actual_headers = { - b"User-Agent": [self.user_agent], - } + actual_headers = {b"User-Agent": [self.user_agent]} if headers: actual_headers.update(headers) - response = yield self.request( - "GET", - url, - headers=Headers(actual_headers), - ) + response = yield self.request("GET", url, headers=Headers(actual_headers)) resp_headers = dict(response.headers.getAllRawHeaders()) - if (b'Content-Length' in resp_headers and - int(resp_headers[b'Content-Length']) > max_size): + if ( + b'Content-Length' in resp_headers + and int(resp_headers[b'Content-Length'][0]) > max_size + ): logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, @@ -359,26 +519,20 @@ class SimpleHttpClient(object): if response.code > 299: logger.warn("Got %d when downloading %s" % (response.code, url)) - raise SynapseError( - 502, - "Got error %d" % (response.code,), - Codes.UNKNOWN, - ) + raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) # TODO: if our Content-Type is HTML or something, just read the first # N bytes into RAM rather than saving it all to disk only to read it # straight back in again try: - length = yield make_deferred_yieldable(_readBodyToFile( - response, output_stream, max_size, - )) + length = yield make_deferred_yieldable( + _readBodyToFile(response, output_stream, max_size) + ) except Exception as e: logger.exception("Failed to download body") raise SynapseError( - 502, - ("Failed to download remote body: %s" % e), - Codes.UNKNOWN, + 502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN ) defer.returnValue( @@ -387,13 +541,14 @@ class SimpleHttpClient(object): resp_headers, response.request.absoluteURI.decode('ascii'), response.code, - ), + ) ) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. + class _ReadBodyToFileProtocol(protocol.Protocol): def __init__(self, stream, deferred, max_size): self.stream = stream @@ -405,11 +560,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback(SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - )) + self.deferred.errback( + SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + ) + ) self.deferred = defer.Deferred() self.transport.loseConnection() @@ -427,6 +584,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol): # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. + def _readBodyToFile(response, stream, max_size): d = defer.Deferred() response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) @@ -449,10 +607,12 @@ class CaptchaServerHttpClient(SimpleHttpClient): "POST", url, data=query_bytes, - headers=Headers({ - b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [self.user_agent], - }) + headers=Headers( + { + b"Content-Type": [b"application/x-www-form-urlencoded"], + b"User-Agent": [self.user_agent], + } + ), ) try: @@ -463,57 +623,6 @@ class CaptchaServerHttpClient(SimpleHttpClient): defer.returnValue(e.response) -class SpiderEndpointFactory(object): - def __init__(self, hs): - self.blacklist = hs.config.url_preview_ip_range_blacklist - self.whitelist = hs.config.url_preview_ip_range_whitelist - self.policyForHTTPS = hs.get_http_client_context_factory() - - def endpointForURI(self, uri): - logger.info("Getting endpoint for %s", uri.toBytes()) - - if uri.scheme == b"http": - endpoint_factory = HostnameEndpoint - elif uri.scheme == b"https": - tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) - - def endpoint_factory(reactor, host, port, **kw): - return wrapClientTLS( - tlsCreator, - HostnameEndpoint(reactor, host, port, **kw)) - else: - logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme) - return None - return SpiderEndpoint( - reactor, uri.host, uri.port, self.blacklist, self.whitelist, - endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15), - ) - - -class SpiderHttpClient(SimpleHttpClient): - """ - Separate HTTP client for spidering arbitrary URLs. - Special in that it follows retries and has a UA that looks - like a browser. - - used by the preview_url endpoint in the content repo. - """ - def __init__(self, hs): - SimpleHttpClient.__init__(self, hs) - # clobber the base class's agent and UA: - self.agent = ContentDecoderAgent( - BrowserLikeRedirectAgent( - Agent.usingEndpointFactory( - reactor, - SpiderEndpointFactory(hs) - ) - ), [(b'gzip', GzipDecoder)] - ) - # We could look like Chrome: - # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) - # Chrome Safari" % hs.version_string) - - def encode_urlencode_args(args): return {k: encode_urlencode_arg(v) for k, v in args.items()} diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 91025037a3..f86a0b624e 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -218,41 +218,6 @@ class _WrappedConnection(object): return d -class SpiderEndpoint(object): - """An endpoint which refuses to connect to blacklisted IP addresses - Implements twisted.internet.interfaces.IStreamClientEndpoint. - """ - def __init__(self, reactor, host, port, blacklist, whitelist, - endpoint=HostnameEndpoint, endpoint_kw_args={}): - self.reactor = reactor - self.host = host - self.port = port - self.blacklist = blacklist - self.whitelist = whitelist - self.endpoint = endpoint - self.endpoint_kw_args = endpoint_kw_args - - @defer.inlineCallbacks - def connect(self, protocolFactory): - address = yield self.reactor.resolve(self.host) - - from netaddr import IPAddress - ip_address = IPAddress(address) - - if ip_address in self.blacklist: - if self.whitelist is None or ip_address not in self.whitelist: - raise ConnectError( - "Refusing to spider blacklisted IP address %s" % address - ) - - logger.info("Connecting to %s:%s", address, self.port) - endpoint = self.endpoint( - self.reactor, address, self.port, **self.endpoint_kw_args - ) - connection = yield endpoint.connect(protocolFactory) - defer.returnValue(connection) - - class SRVClientEndpoint(object): """An endpoint which looks up SRV records for a service. Cycles through the list of servers starting with each call to connect diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index d0ecf241b6..ba3ab1d37d 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -35,7 +35,7 @@ from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError -from synapse.http.client import SpiderHttpClient +from synapse.http.client import SimpleHttpClient from synapse.http.server import ( respond_with_json, respond_with_json_bytes, @@ -69,7 +69,12 @@ class PreviewUrlResource(Resource): self.max_spider_size = hs.config.max_spider_size self.server_name = hs.hostname self.store = hs.get_datastore() - self.client = SpiderHttpClient(hs) + self.client = SimpleHttpClient( + hs, + treq_args={"browser_like_redirects": True}, + ip_whitelist=hs.config.url_preview_ip_range_whitelist, + ip_blacklist=hs.config.url_preview_ip_range_blacklist, + ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path self.media_storage = media_storage @@ -318,6 +323,11 @@ class PreviewUrlResource(Resource): length, headers, uri, code = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) + except SynapseError: + # Pass SynapseErrors through directly, so that the servlet + # handler will return a SynapseError to the client instead of + # blank data or a 500. + raise except Exception as e: # FIXME: pass through 404s and other error messages nicely logger.warn("Error downloading %s: %r", url, e) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index c62f71b44a..650ce95a6f 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -15,21 +15,55 @@ import os -from mock import Mock +import attr +from netaddr import IPSet -from twisted.internet.defer import Deferred +from twisted.internet._resolver import HostResolution +from twisted.internet.address import IPv4Address, IPv6Address +from twisted.internet.error import DNSLookupError +from twisted.python.failure import Failure +from twisted.test.proto_helpers import AccumulatingProtocol +from twisted.web._newclient import ResponseDone from synapse.config.repository import MediaStorageProviderConfig -from synapse.util.logcontext import make_deferred_yieldable from synapse.util.module_loader import load_module from tests import unittest +from tests.server import FakeTransport + + +@attr.s +class FakeResponse(object): + version = attr.ib() + code = attr.ib() + phrase = attr.ib() + headers = attr.ib() + body = attr.ib() + absoluteURI = attr.ib() + + @property + def request(self): + @attr.s + class FakeTransport(object): + absoluteURI = self.absoluteURI + + return FakeTransport() + + def deliverBody(self, protocol): + protocol.dataReceived(self.body) + protocol.connectionLost(Failure(ResponseDone())) class URLPreviewTests(unittest.HomeserverTestCase): hijack_auth = True user_id = "@test:user" + end_content = ( + b'' + b'' + b'' + b'' + ) def make_homeserver(self, reactor, clock): @@ -39,6 +73,15 @@ class URLPreviewTests(unittest.HomeserverTestCase): config = self.default_config() config.url_preview_enabled = True config.max_spider_size = 9999999 + config.url_preview_ip_range_blacklist = IPSet( + ( + "192.168.1.1", + "1.0.0.0/8", + "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + "2001:800::/21", + ) + ) + config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",)) config.url_preview_url_blacklist = [] config.media_store_path = self.storage_path @@ -62,63 +105,50 @@ class URLPreviewTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.fetches = [] + self.media_repo = hs.get_media_repository_resource() + self.preview_url = self.media_repo.children[b'preview_url'] - def get_file(url, output_stream, max_size): - """ - Returns tuple[int,dict,str,int] of file length, response headers, - absolute URI, and response code. - """ + self.lookups = {} - def write_to(r): - data, response = r - output_stream.write(data) - return response + class Resolver(object): + def resolveHostName( + _self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics='TCP', + ): - d = Deferred() - d.addCallback(write_to) - self.fetches.append((d, url)) - return make_deferred_yieldable(d) + resolution = HostResolution(hostName) + resolutionReceiver.resolutionBegan(resolution) + if hostName not in self.lookups: + raise DNSLookupError("OH NO") - client = Mock() - client.get_file = get_file + for i in self.lookups[hostName]: + resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber)) + resolutionReceiver.resolutionComplete() + return resolutionReceiver - self.media_repo = hs.get_media_repository_resource() - preview_url = self.media_repo.children[b'preview_url'] - preview_url.client = client - self.preview_url = preview_url + self.reactor.nameResolver = Resolver() def test_cache_returns_correct_type(self): + self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) self.pump() - # We've made one fetch - self.assertEqual(len(self.fetches), 1) - - end_content = ( - b'' - b'' - b'' - b'' - ) - - self.fetches[0][0].callback( - ( - end_content, - ( - len(end_content), - { - b"Content-Length": [b"%d" % (len(end_content))], - b"Content-Type": [b'text/html; charset="utf8"'], - }, - "https://example.com", - 200, - ), - ) + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content ) self.pump() @@ -129,14 +159,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Check the cache returns the correct response request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) self.pump() - # Only one fetch, still, since we'll lean on the cache - self.assertEqual(len(self.fetches), 1) - # Check the cache response has the same content self.assertEqual(channel.code, 200) self.assertEqual( @@ -144,20 +171,17 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) # Clear the in-memory cache - self.assertIn("matrix.org", self.preview_url._cache) - self.preview_url._cache.pop("matrix.org") - self.assertNotIn("matrix.org", self.preview_url._cache) + self.assertIn("http://matrix.org", self.preview_url._cache) + self.preview_url._cache.pop("http://matrix.org") + self.assertNotIn("http://matrix.org", self.preview_url._cache) # Check the database cache returns the correct response request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) self.pump() - # Only one fetch, still, since we'll lean on the cache - self.assertEqual(len(self.fetches), 1) - # Check the cache response has the same content self.assertEqual(channel.code, 200) self.assertEqual( @@ -165,78 +189,282 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) def test_non_ascii_preview_httpequiv(self): + self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + + end_content = ( + b'' + b'' + b'' + b'' + b'' + ) request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) self.pump() - # We've made one fetch - self.assertEqual(len(self.fetches), 1) + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n" + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") + + def test_non_ascii_preview_content_type(self): + self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] end_content = ( b'' - b'' b'' b'' b'' ) - self.fetches[0][0].callback( + request, channel = self.make_request( + "GET", "url_preview?url=http://matrix.org", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( ( - end_content, - ( - len(end_content), - { - b"Content-Length": [b"%d" % (len(end_content))], - # This charset=utf-8 should be ignored, because the - # document has a meta tag overriding it. - b"Content-Type": [b'text/html; charset="utf8"'], - }, - "https://example.com", - 200, - ), + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n" ) + % (len(end_content),) + + end_content ) self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") - def test_non_ascii_preview_content_type(self): + def test_ipaddr(self): + """ + IP addresses can be previewed directly. + """ + self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://example.com", shorthand=False ) request.render(self.preview_url) self.pump() - # We've made one fetch - self.assertEqual(len(self.fetches), 1) + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content + ) - end_content = ( - b'' - b'' - b'' - b'' + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - self.fetches[0][0].callback( - ( - end_content, - ( - len(end_content), - { - b"Content-Length": [b"%d" % (len(end_content))], - b"Content-Type": [b'text/html; charset="windows-1251"'], - }, - "https://example.com", - 200, - ), - ) + def test_blacklisted_ip_specific(self): + """ + Blacklisted IP addresses, found via DNS, are not spidered. + """ + self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # No requests made. + self.assertEqual(len(self.reactor.tcpClients), 0) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ip_range(self): + """ + Blacklisted IP ranges, IPs found over DNS, are not spidered. + """ + self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ip_specific_direct(self): + """ + Blacklisted IP addresses, accessed directly, are not spidered. + """ + request, channel = self.make_request( + "GET", "url_preview?url=http://192.168.1.1", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # No requests made. + self.assertEqual(len(self.reactor.tcpClients), 0) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ip_range_direct(self): + """ + Blacklisted IP ranges, accessed directly, are not spidered. + """ + request, channel = self.make_request( + "GET", "url_preview?url=http://1.1.1.2", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ip_range_whitelisted_ip(self): + """ + Blacklisted but then subsequently whitelisted IP addresses can be + spidered. + """ + self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content ) self.pump() self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} + ) + + def test_blacklisted_ip_with_external_ip(self): + """ + If a hostname resolves a blacklisted IP, even if there's a + non-blacklisted one, it will be rejected. + """ + # Hardcode the URL resolving to the IP we want. + self.lookups[u"example.com"] = [ + (IPv4Address, "1.1.1.2"), + (IPv4Address, "8.8.8.8"), + ] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ipv6_specific(self): + """ + Blacklisted IP addresses, found via DNS, are not spidered. + """ + self.lookups["example.com"] = [ + (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") + ] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # No requests made. + self.assertEqual(len(self.reactor.tcpClients), 0) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ipv6_range(self): + """ + Blacklisted IP ranges, IPs found over DNS, are not spidered. + """ + self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) diff --git a/tests/server.py b/tests/server.py index ceec2f2d4e..db43fa0db8 100644 --- a/tests/server.py +++ b/tests/server.py @@ -383,8 +383,16 @@ class FakeTransport(object): self.disconnecting = True def pauseProducing(self): + if not self.producer: + return + self.producer.pauseProducing() + def resumeProducing(self): + if not self.producer: + return + self.producer.resumeProducing() + def unregisterProducer(self): if not self.producer: return -- cgit 1.4.1 From 58f6c4818337364dd9c6bf01062e7b0dadcb8a25 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 24 Jan 2019 21:31:54 +1100 Subject: Use native UPSERTs where possible (#4306) --- .coveragerc | 6 +- .gitignore | 6 +- changelog.d/4306.misc | 1 + synapse/storage/_base.py | 148 +++++++++++++++++++++++++++++++++--- synapse/storage/client_ips.py | 5 +- synapse/storage/engines/__init__.py | 2 +- synapse/storage/engines/postgres.py | 14 ++++ synapse/storage/engines/sqlite.py | 96 +++++++++++++++++++++++ synapse/storage/engines/sqlite3.py | 87 --------------------- synapse/storage/pusher.py | 9 ++- synapse/storage/user_directory.py | 55 ++++++++++---- tests/storage/test_base.py | 1 + tests/test_server.py | 12 ++- tests/unittest.py | 12 ++- tox.ini | 1 + 15 files changed, 325 insertions(+), 130 deletions(-) create mode 100644 changelog.d/4306.misc create mode 100644 synapse/storage/engines/sqlite.py delete mode 100644 synapse/storage/engines/sqlite3.py (limited to '.coveragerc') diff --git a/.coveragerc b/.coveragerc index 9873a30738..e9460a340a 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,11 +1,7 @@ [run] branch = True parallel = True -source = synapse - -[paths] -source= - coverage +include = synapse/* [report] precision = 2 diff --git a/.gitignore b/.gitignore index d739595c3a..1033124f1d 100644 --- a/.gitignore +++ b/.gitignore @@ -25,9 +25,9 @@ homeserver*.pid *.tls.dh *.tls.key -.coverage -.coverage.* -!.coverage.rc +.coverage* +coverage.* +!.coveragerc htmlcov demo/*/*.db diff --git a/changelog.d/4306.misc b/changelog.d/4306.misc new file mode 100644 index 0000000000..58130b6190 --- /dev/null +++ b/changelog.d/4306.misc @@ -0,0 +1 @@ +Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 865b5e915a..254fdc04c6 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -192,6 +192,41 @@ class SQLBaseStore(object): self.database_engine = hs.database_engine + # A set of tables that are not safe to use native upserts in. + self._unsafe_to_upsert_tables = {"user_ips"} + + if self.database_engine.can_native_upsert: + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates of tables that aren't safe to update. + self._clock.call_later(0.0, self._check_safe_to_upsert) + + @defer.inlineCallbacks + def _check_safe_to_upsert(self): + """ + Is it safe to use native UPSERT? + + If there are background updates, we will need to wait, as they may be + the addition of indexes that set the UNIQUE constraint that we require. + + If the background updates have not completed, wait a second and check again. + """ + updates = yield self._simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ) + updates = [x["update_name"] for x in updates] + + # The User IPs table in schema #53 was missing a unique index, which we + # run as a background update. + if "user_ips_device_unique_index" not in updates: + self._unsafe_to_upsert_tables.discard("user_id") + + # If there's any tables left to check, reschedule to run. + if self._unsafe_to_upsert_tables: + self._clock.call_later(1.0, self._check_safe_to_upsert) + def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() @@ -494,8 +529,15 @@ class SQLBaseStore(object): txn.executemany(sql, vals) @defer.inlineCallbacks - def _simple_upsert(self, table, keyvalues, values, - insertion_values={}, desc="_simple_upsert", lock=True): + def _simple_upsert( + self, + table, + keyvalues, + values, + insertion_values={}, + desc="_simple_upsert", + lock=True + ): """ `lock` should generally be set to True (the default), but can be set @@ -516,16 +558,21 @@ class SQLBaseStore(object): inserting lock (bool): True to lock the table when doing the upsert. Returns: - Deferred(bool): True if a new entry was created, False if an - existing one was updated. + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. """ attempts = 0 while True: try: result = yield self.runInteraction( desc, - self._simple_upsert_txn, table, keyvalues, values, insertion_values, - lock=lock + self._simple_upsert_txn, + table, + keyvalues, + values, + insertion_values, + lock=lock, ) defer.returnValue(result) except self.database_engine.module.IntegrityError as e: @@ -537,12 +584,59 @@ class SQLBaseStore(object): # presumably we raced with another transaction: let's retry. logger.warn( - "IntegrityError when upserting into %s; retrying: %s", - table, e + "%s when upserting into %s; retrying: %s", e.__name__, table, e ) - def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}, - lock=True): + def _simple_upsert_txn( + self, + txn, + table, + keyvalues, + values, + insertion_values={}, + lock=True, + ): + """ + Pick the UPSERT method which works best on the platform. Either the + native one (Pg9.5+, recent SQLites), or fall back to an emulated method. + + Args: + txn: The transaction to use. + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + if ( + self.database_engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ): + return self._simple_upsert_txn_native_upsert( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + ) + else: + return self._simple_upsert_txn_emulated( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + lock=lock, + ) + + def _simple_upsert_txn_emulated( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): # We need to lock the table :(, unless we're *really* careful if lock: self.database_engine.lock_table(txn, table) @@ -577,12 +671,44 @@ class SQLBaseStore(object): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( table, ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues) + ", ".join("?" for _ in allvalues), ) txn.execute(sql, list(allvalues.values())) # successfully inserted return True + def _simple_upsert_txn_native_upsert( + self, txn, table, keyvalues, values, insertion_values={} + ): + """ + Use the native UPSERT functionality in recent PostgreSQL versions. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + Returns: + None + """ + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) + + sql = ( + "INSERT INTO %s (%s) VALUES (%s) " + "ON CONFLICT (%s) DO UPDATE SET %s" + ) % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ", ".join(k for k in keyvalues), + ", ".join(k + "=EXCLUDED." + k for k in values), + ) + txn.execute(sql, list(allvalues.values())) + def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index b228a20ac2..091d7116c5 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -257,7 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): ) def _update_client_ips_batch_txn(self, txn, to_update): - self.database_engine.lock_table(txn, "user_ips") + if "user_ips" in self._unsafe_to_upsert_tables or ( + not self.database_engine.can_native_upsert + ): + self.database_engine.lock_table(txn, "user_ips") for entry in iteritems(to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index e2f9de8451..ff5ef97ca8 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -18,7 +18,7 @@ import platform from ._base import IncorrectDatabaseSetup from .postgres import PostgresEngine -from .sqlite3 import Sqlite3Engine +from .sqlite import Sqlite3Engine SUPPORTED_MODULE = { "sqlite3": Sqlite3Engine, diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 42225f8a2a..4004427c7b 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -38,6 +38,13 @@ class PostgresEngine(object): return sql.replace("?", "%s") def on_new_connection(self, db_conn): + + # Get the version of PostgreSQL that we're using. As per the psycopg2 + # docs: The number is formed by converting the major, minor, and + # revision numbers into two-decimal-digit numbers and appending them + # together. For example, version 8.1.5 will be returned as 80105 + self._version = db_conn.server_version + db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) @@ -54,6 +61,13 @@ class PostgresEngine(object): cursor.close() + @property + def can_native_upsert(self): + """ + Can we use native UPSERTs? This requires PostgreSQL 9.5+. + """ + return self._version >= 90500 + def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py new file mode 100644 index 0000000000..c64d73ff21 --- /dev/null +++ b/synapse/storage/engines/sqlite.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import struct +import threading +from sqlite3 import sqlite_version_info + +from synapse.storage.prepare_database import prepare_database + + +class Sqlite3Engine(object): + single_threaded = True + + def __init__(self, database_module, database_config): + self.module = database_module + + # The current max state_group, or None if we haven't looked + # in the DB yet. + self._current_state_group_id = None + self._current_state_group_id_lock = threading.Lock() + + @property + def can_native_upsert(self): + """ + Do we support native UPSERTs? This requires SQLite3 3.24+, plus some + more work we haven't done yet to tell what was inserted vs updated. + """ + return sqlite_version_info >= (3, 24, 0) + + def check_database(self, txn): + pass + + def convert_param_style(self, sql): + return sql + + def on_new_connection(self, db_conn): + prepare_database(db_conn, self, config=None) + db_conn.create_function("rank", 1, _rank) + + def is_deadlock(self, error): + return False + + def is_connection_closed(self, conn): + return False + + def lock_table(self, txn, table): + return + + def get_next_state_group_id(self, txn): + """Returns an int that can be used as a new state_group ID + """ + # We do application locking here since if we're using sqlite then + # we are a single process synapse. + with self._current_state_group_id_lock: + if self._current_state_group_id is None: + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + self._current_state_group_id = txn.fetchone()[0] + + self._current_state_group_id += 1 + return self._current_state_group_id + + +# Following functions taken from: https://github.com/coleifer/peewee + +def _parse_match_info(buf): + bufsize = len(buf) + return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)] + + +def _rank(raw_match_info): + """Handle match_info called w/default args 'pcx' - based on the example rank + function http://sqlite.org/fts3.html#appendix_a + """ + match_info = _parse_match_info(raw_match_info) + score = 0.0 + p, c = match_info[:2] + for phrase_num in range(p): + phrase_info_idx = 2 + (phrase_num * c * 3) + for col_num in range(c): + col_idx = phrase_info_idx + (col_num * 3) + x1, x2 = match_info[col_idx:col_idx + 2] + if x1 > 0: + score += float(x1) / x2 + return score diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py deleted file mode 100644 index 19949fc474..0000000000 --- a/synapse/storage/engines/sqlite3.py +++ /dev/null @@ -1,87 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import struct -import threading - -from synapse.storage.prepare_database import prepare_database - - -class Sqlite3Engine(object): - single_threaded = True - - def __init__(self, database_module, database_config): - self.module = database_module - - # The current max state_group, or None if we haven't looked - # in the DB yet. - self._current_state_group_id = None - self._current_state_group_id_lock = threading.Lock() - - def check_database(self, txn): - pass - - def convert_param_style(self, sql): - return sql - - def on_new_connection(self, db_conn): - prepare_database(db_conn, self, config=None) - db_conn.create_function("rank", 1, _rank) - - def is_deadlock(self, error): - return False - - def is_connection_closed(self, conn): - return False - - def lock_table(self, txn, table): - return - - def get_next_state_group_id(self, txn): - """Returns an int that can be used as a new state_group ID - """ - # We do application locking here since if we're using sqlite then - # we are a single process synapse. - with self._current_state_group_id_lock: - if self._current_state_group_id is None: - txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") - self._current_state_group_id = txn.fetchone()[0] - - self._current_state_group_id += 1 - return self._current_state_group_id - - -# Following functions taken from: https://github.com/coleifer/peewee - -def _parse_match_info(buf): - bufsize = len(buf) - return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)] - - -def _rank(raw_match_info): - """Handle match_info called w/default args 'pcx' - based on the example rank - function http://sqlite.org/fts3.html#appendix_a - """ - match_info = _parse_match_info(raw_match_info) - score = 0.0 - p, c = match_info[:2] - for phrase_num in range(p): - phrase_info_idx = 2 + (phrase_num * c * 3) - for col_num in range(c): - col_idx = phrase_info_idx + (col_num * 3) - x1, x2 = match_info[col_idx:col_idx + 2] - if x1 > 0: - score += float(x1) / x2 - return score diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 2743b52bad..134297e284 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so _simple_upsert will retry - newly_inserted = yield self._simple_upsert( + yield self._simple_upsert( table="pushers", keyvalues={ "app_id": app_id, @@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore): lock=False, ) - if newly_inserted: + user_has_pusher = self.get_if_user_has_pusher.cache.get( + (user_id,), None, update_metrics=False + ) + + if user_has_pusher is not True: + # invalidate, since we the user might not have had a pusher before yield self.runInteraction( "add_pusher", self._invalidate_cache_and_stream, diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index a8781b0e5d..ce48212265 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -168,14 +168,14 @@ class UserDirectoryStore(SQLBaseStore): if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name - if new_entry: + if self.database_engine.can_native_upsert: sql = """ INSERT INTO user_directory_search(user_id, vector) VALUES (?, setweight(to_tsvector('english', ?), 'A') || setweight(to_tsvector('english', ?), 'D') || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - ) + ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector """ txn.execute( sql, @@ -185,20 +185,45 @@ class UserDirectoryStore(SQLBaseStore): ) ) else: - sql = """ - UPDATE user_directory_search - SET vector = setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - WHERE user_id = ? - """ - txn.execute( - sql, - ( - get_localpart_from_id(user_id), get_domain_from_id(user_id), - display_name, user_id, + # TODO: Remove this code after we've bumped the minimum version + # of postgres to always support upserts, so we can get rid of + # `new_entry` usage + if new_entry is True: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + txn.execute( + sql, + ( + user_id, get_localpart_from_id(user_id), + get_domain_from_id(user_id), display_name, + ) + ) + elif new_entry is False: + sql = """ + UPDATE user_directory_search + SET vector = setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + WHERE user_id = ? + """ + txn.execute( + sql, + ( + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, user_id, + ) + ) + else: + raise RuntimeError( + "upsert returned None when 'can_native_upsert' is False" ) - ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name,) if display_name else user_id self._simple_upsert_txn( diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 829f47d2e8..452d76ddd5 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.db_pool.runWithConnection = runWithConnection config = Mock() + config._enable_native_upserts = False config.event_cache_size = 1 config.database_config = {"name": "sqlite3"} hs = TestHomeServer( diff --git a/tests/test_server.py b/tests/test_server.py index 634a8fbca5..08fb3fe02f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -19,7 +19,7 @@ from six import StringIO from twisted.internet.defer import Deferred from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -30,12 +30,18 @@ from synapse.util import Clock from synapse.util.logcontext import make_deferred_yieldable from tests import unittest -from tests.server import FakeTransport, make_request, render, setup_test_homeserver +from tests.server import ( + FakeTransport, + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) class JsonResourceTests(unittest.TestCase): def setUp(self): - self.reactor = MemoryReactorClock() + self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor diff --git a/tests/unittest.py b/tests/unittest.py index 78d2f740f9..cda549c783 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -96,7 +96,7 @@ class TestCase(unittest.TestCase): method = getattr(self, methodName) - level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR)) + level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING)) @around(self) def setUp(orig): @@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase): """ kwargs = dict(kwargs) kwargs.update(self._hs_args) - return setup_test_homeserver(self.addCleanup, *args, **kwargs) + hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) + stor = hs.get_datastore() + + # Run the database background updates. + if hasattr(stor, "do_next_background_update"): + while not self.get_success(stor.has_completed_background_updates()): + self.get_success(stor.do_next_background_update(1)) + + return hs def pump(self, by=0.0): """ diff --git a/tox.ini b/tox.ini index a0f5486829..9b2d78ed6d 100644 --- a/tox.ini +++ b/tox.ini @@ -149,4 +149,5 @@ deps = codecov commands = coverage combine + coverage xml codecov -X gcov -- cgit 1.4.1