diff options
author | Mark Haines <mjark@negativecurvature.net> | 2016-04-21 16:21:49 +0100 |
---|---|---|
committer | Mark Haines <mjark@negativecurvature.net> | 2016-04-21 16:21:49 +0100 |
commit | 712030aeef97c414d641a65b398355ed74dc7baf (patch) | |
tree | a355a2b8c9e66264a991dff3d41d3e19cdd90d5e | |
parent | Add an HTTP API for removing rejected pushers. (diff) | |
parent | pip install new python dependencies in jenkins.sh (diff) | |
download | synapse-712030aeef97c414d641a65b398355ed74dc7baf.tar.xz |
Merge branch 'develop' into markjh/split_pusher
-rwxr-xr-x | jenkins-postgres.sh | 2 | ||||
-rwxr-xr-x | jenkins-sqlite.sh | 2 | ||||
-rwxr-xr-x | jenkins.sh | 86 | ||||
-rw-r--r-- | synapse/federation/transport/client.py | 3 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 23 | ||||
-rw-r--r-- | synapse/http/client.py | 5 | ||||
-rw-r--r-- | synapse/rest/media/v1/_base.py | 110 | ||||
-rw-r--r-- | synapse/rest/media/v1/base_resource.py | 460 | ||||
-rw-r--r-- | synapse/rest/media/v1/download_resource.py | 24 | ||||
-rw-r--r-- | synapse/rest/media/v1/media_repository.py | 395 | ||||
-rw-r--r-- | synapse/rest/media/v1/preview_url_resource.py | 55 | ||||
-rw-r--r-- | synapse/rest/media/v1/thumbnail_resource.py | 51 | ||||
-rw-r--r-- | synapse/rest/media/v1/upload_resource.py | 51 | ||||
-rw-r--r-- | synapse/state.py | 18 | ||||
-rw-r--r-- | synapse/storage/state.py | 19 | ||||
-rw-r--r-- | synapse/util/__init__.py | 3 | ||||
-rw-r--r-- | synapse/util/metrics.py | 23 | ||||
-rw-r--r-- | tests/replication/slave/storage/_base.py | 4 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_events.py | 15 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_receipts.py | 39 | ||||
-rw-r--r-- | tests/test_state.py | 4 |
21 files changed, 733 insertions, 659 deletions
diff --git a/jenkins-postgres.sh b/jenkins-postgres.sh index 9ac86d2593..ae6b111591 100755 --- a/jenkins-postgres.sh +++ b/jenkins-postgres.sh @@ -25,7 +25,9 @@ rm .coverage* || echo "No coverage files to remove" tox --notest -e py27 TOX_BIN=$WORKSPACE/.tox/py27/bin +python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install $TOX_BIN/pip install psycopg2 +$TOX_BIN/pip install lxml : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} diff --git a/jenkins-sqlite.sh b/jenkins-sqlite.sh index 345d01936c..9398d9db15 100755 --- a/jenkins-sqlite.sh +++ b/jenkins-sqlite.sh @@ -24,6 +24,8 @@ rm .coverage* || echo "No coverage files to remove" tox --notest -e py27 TOX_BIN=$WORKSPACE/.tox/py27/bin +python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install +$TOX_BIN/pip install lxml : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} diff --git a/jenkins.sh b/jenkins.sh deleted file mode 100755 index b826d510c9..0000000000 --- a/jenkins.sh +++ /dev/null @@ -1,86 +0,0 @@ -#!/bin/bash - -set -eux - -: ${WORKSPACE:="$(pwd)"} - -export PYTHONDONTWRITEBYTECODE=yep -export SYNAPSE_CACHE_FACTOR=1 - -# Output test results as junit xml -export TRIAL_FLAGS="--reporter=subunit" -export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml" -# Write coverage reports to a separate file for each process -export COVERAGE_OPTS="-p" -export DUMP_COVERAGE_COMMAND="coverage help" - -# Output flake8 violations to violations.flake8.log -# Don't exit with non-0 status code on Jenkins, -# so that the build steps continue and a later step can decided whether to -# UNSTABLE or FAILURE this build. -export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?" - -rm .coverage* || echo "No coverage files to remove" - -tox - -: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} - -TOX_BIN=$WORKSPACE/.tox/py27/bin - -if [[ ! -e .sytest-base ]]; then - git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror -else - (cd .sytest-base; git fetch -p) -fi - -rm -rf sytest -git clone .sytest-base sytest --shared -cd sytest - -git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop) - -: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5} -: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5} -: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5} -export PERL5LIB PERL_MB_OPT PERL_MM_OPT - -./install-deps.pl - -: ${PORT_BASE:=8000} - -echo >&2 "Running sytest with SQLite3"; -./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \ - --python $TOX_BIN/python --all --port-base $PORT_BASE > results-sqlite3.tap - -RUN_POSTGRES="" - -for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do - if psql synapse_jenkins_$port <<< ""; then - RUN_POSTGRES="$RUN_POSTGRES:$port" - cat > localhost-$port/database.yaml << EOF -name: psycopg2 -args: - database: synapse_jenkins_$port -EOF - fi -done - -# Run if both postgresql databases exist -if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then - echo >&2 "Running sytest with PostgreSQL"; - $TOX_BIN/pip install psycopg2 - ./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \ - --python $TOX_BIN/python --all --port-base $PORT_BASE > results-postgresql.tap -else - echo >&2 "Skipping running sytest with PostgreSQL, $RUN_POSTGRES" -fi - -cd .. -cp sytest/.coverage.* . - -# Combine the coverage reports -echo "Combining:" .coverage.* -$TOX_BIN/python -m coverage combine -# Output coverage to coverage.xml -$TOX_BIN/coverage xml -o coverage.xml diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 2237e3413c..cd2841c4db 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -179,7 +179,8 @@ class TransportLayerClient(object): content = yield self.client.get_json( destination=destination, path=path, - retry_on_dns_fail=True, + retry_on_dns_fail=False, + timeout=20000, ) defer.returnValue(content) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7a13a8b11c..61fe56032a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -428,24 +428,31 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_password(self, user_id, password): - defer.returnValue( - not ( - (yield self._check_ldap_password(user_id, password)) - or - (yield self._check_local_password(user_id, password)) - )) + """ + Returns: + True if the user_id successfully authenticated + """ + valid_ldap = yield self._check_ldap_password(user_id, password) + if valid_ldap: + defer.returnValue(True) + + valid_local_password = yield self._check_local_password(user_id, password) + if valid_local_password: + defer.returnValue(True) + + defer.returnValue(False) @defer.inlineCallbacks def _check_local_password(self, user_id, password): try: user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - defer.returnValue(not self.validate_hash(password, password_hash)) + defer.returnValue(self.validate_hash(password, password_hash)) except LoginError: defer.returnValue(False) @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): - if self.ldap_enabled is not True: + if not self.ldap_enabled: logger.debug("LDAP not configured") defer.returnValue(False) diff --git a/synapse/http/client.py b/synapse/http/client.py index 6c89b20984..902ae7a203 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -462,5 +462,8 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): self._context = SSL.Context(SSL.SSLv23_METHOD) self._context.set_verify(VERIFY_NONE, lambda *_: None) - def getContext(self, hostname, port): + def getContext(self, hostname=None, port=None): return self._context + + def creatorForNetloc(self, hostname, port): + return self diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py new file mode 100644 index 0000000000..b9600f2167 --- /dev/null +++ b/synapse/rest/media/v1/_base.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import respond_with_json, finish_request +from synapse.api.errors import ( + cs_error, Codes, SynapseError +) + +from twisted.internet import defer +from twisted.protocols.basic import FileSender + +from synapse.util.stringutils import is_ascii + +import os + +import logging +import urllib +import urlparse + +logger = logging.getLogger(__name__) + + +def parse_media_id(request): + try: + # This allows users to append e.g. /test.png to the URL. Useful for + # clients that parse the URL to see content type. + server_name, media_id = request.postpath[:2] + file_name = None + if len(request.postpath) > 2: + try: + file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8") + except UnicodeDecodeError: + pass + return server_name, media_id, file_name + except: + raise SynapseError( + 404, + "Invalid media id token %r" % (request.postpath,), + Codes.UNKNOWN, + ) + + +def respond_404(request): + respond_with_json( + request, 404, + cs_error( + "Not found %r" % (request.postpath,), + code=Codes.NOT_FOUND, + ), + send_cors=True + ) + + +@defer.inlineCallbacks +def respond_with_file(request, media_type, file_path, + file_size=None, upload_name=None): + logger.debug("Responding with %r", file_path) + + if os.path.isfile(file_path): + request.setHeader(b"Content-Type", media_type.encode("UTF-8")) + if upload_name: + if is_ascii(upload_name): + request.setHeader( + b"Content-Disposition", + b"inline; filename=%s" % ( + urllib.quote(upload_name.encode("utf-8")), + ), + ) + else: + request.setHeader( + b"Content-Disposition", + b"inline; filename*=utf-8''%s" % ( + urllib.quote(upload_name.encode("utf-8")), + ), + ) + + # cache for at least a day. + # XXX: we might want to turn this off for data we don't want to + # recommend caching as it's sensitive or private - or at least + # select private. don't bother setting Expires as all our + # clients are smart enough to be happy with Cache-Control + request.setHeader( + b"Cache-Control", b"public,max-age=86400,s-maxage=86400" + ) + if file_size is None: + stat = os.stat(file_path) + file_size = stat.st_size + + request.setHeader( + b"Content-Length", b"%d" % (file_size,) + ) + + with open(file_path, "rb") as f: + yield FileSender().beginFileTransfer(f, request) + + finish_request(request) + else: + respond_404(request) diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py deleted file mode 100644 index 2b1938dc8e..0000000000 --- a/synapse/rest/media/v1/base_resource.py +++ /dev/null @@ -1,460 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .thumbnailer import Thumbnailer - -from synapse.http.matrixfederationclient import MatrixFederationHttpClient -from synapse.http.server import respond_with_json, finish_request -from synapse.util.stringutils import random_string -from synapse.api.errors import ( - cs_error, Codes, SynapseError -) - -from twisted.internet import defer, threads -from twisted.web.resource import Resource -from twisted.protocols.basic import FileSender - -from synapse.util.async import ObservableDeferred -from synapse.util.stringutils import is_ascii -from synapse.util.logcontext import preserve_context_over_fn - -import os - -import cgi -import logging -import urllib -import urlparse - -logger = logging.getLogger(__name__) - - -def parse_media_id(request): - try: - # This allows users to append e.g. /test.png to the URL. Useful for - # clients that parse the URL to see content type. - server_name, media_id = request.postpath[:2] - file_name = None - if len(request.postpath) > 2: - try: - file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8") - except UnicodeDecodeError: - pass - return server_name, media_id, file_name - except: - raise SynapseError( - 404, - "Invalid media id token %r" % (request.postpath,), - Codes.UNKNOWN, - ) - - -class BaseMediaResource(Resource): - isLeaf = True - - def __init__(self, hs, filepaths): - Resource.__init__(self) - self.auth = hs.get_auth() - self.client = MatrixFederationHttpClient(hs) - self.clock = hs.get_clock() - self.server_name = hs.hostname - self.store = hs.get_datastore() - self.max_upload_size = hs.config.max_upload_size - self.max_image_pixels = hs.config.max_image_pixels - self.max_spider_size = hs.config.max_spider_size - self.filepaths = filepaths - self.version_string = hs.version_string - self.downloads = {} - self.dynamic_thumbnails = hs.config.dynamic_thumbnails - self.thumbnail_requirements = hs.config.thumbnail_requirements - - def _respond_404(self, request): - respond_with_json( - request, 404, - cs_error( - "Not found %r" % (request.postpath,), - code=Codes.NOT_FOUND, - ), - send_cors=True - ) - - @staticmethod - def _makedirs(filepath): - dirname = os.path.dirname(filepath) - if not os.path.exists(dirname): - os.makedirs(dirname) - - def _get_remote_media(self, server_name, media_id): - key = (server_name, media_id) - download = self.downloads.get(key) - if download is None: - download = self._get_remote_media_impl(server_name, media_id) - download = ObservableDeferred( - download, - consumeErrors=True - ) - self.downloads[key] = download - - @download.addBoth - def callback(media_info): - del self.downloads[key] - return media_info - return download.observe() - - @defer.inlineCallbacks - def _get_remote_media_impl(self, server_name, media_id): - media_info = yield self.store.get_cached_remote_media( - server_name, media_id - ) - if not media_info: - media_info = yield self._download_remote_file( - server_name, media_id - ) - defer.returnValue(media_info) - - @defer.inlineCallbacks - def _download_remote_file(self, server_name, media_id): - file_id = random_string(24) - - fname = self.filepaths.remote_media_filepath( - server_name, file_id - ) - self._makedirs(fname) - - try: - with open(fname, "wb") as f: - request_path = "/".join(( - "/_matrix/media/v1/download", server_name, media_id, - )) - length, headers = yield self.client.get_file( - server_name, request_path, output_stream=f, - max_size=self.max_upload_size, - ) - media_type = headers["Content-Type"][0] - time_now_ms = self.clock.time_msec() - - content_disposition = headers.get("Content-Disposition", None) - if content_disposition: - _, params = cgi.parse_header(content_disposition[0],) - upload_name = None - - # First check if there is a valid UTF-8 filename - upload_name_utf8 = params.get("filename*", None) - if upload_name_utf8: - if upload_name_utf8.lower().startswith("utf-8''"): - upload_name = upload_name_utf8[7:] - - # If there isn't check for an ascii name. - if not upload_name: - upload_name_ascii = params.get("filename", None) - if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii - - if upload_name: - upload_name = urlparse.unquote(upload_name) - try: - upload_name = upload_name.decode("utf-8") - except UnicodeDecodeError: - upload_name = None - else: - upload_name = None - - yield self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - ) - except: - os.remove(fname) - raise - - media_info = { - "media_type": media_type, - "media_length": length, - "upload_name": upload_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - } - - yield self._generate_remote_thumbnails( - server_name, media_id, media_info - ) - - defer.returnValue(media_info) - - @defer.inlineCallbacks - def _respond_with_file(self, request, media_type, file_path, - file_size=None, upload_name=None): - logger.debug("Responding with %r", file_path) - - if os.path.isfile(file_path): - request.setHeader(b"Content-Type", media_type.encode("UTF-8")) - if upload_name: - if is_ascii(upload_name): - request.setHeader( - b"Content-Disposition", - b"inline; filename=%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) - else: - request.setHeader( - b"Content-Disposition", - b"inline; filename*=utf-8''%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) - - # cache for at least a day. - # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. don't bother setting Expires as all our - # clients are smart enough to be happy with Cache-Control - request.setHeader( - b"Cache-Control", b"public,max-age=86400,s-maxage=86400" - ) - if file_size is None: - stat = os.stat(file_path) - file_size = stat.st_size - - request.setHeader( - b"Content-Length", b"%d" % (file_size,) - ) - - with open(file_path, "rb") as f: - yield FileSender().beginFileTransfer(f, request) - - finish_request(request) - else: - self._respond_404(request) - - def _get_thumbnail_requirements(self, media_type): - return self.thumbnail_requirements.get(media_type, ()) - - def _generate_thumbnail(self, input_path, t_path, t_width, t_height, - t_method, t_type): - thumbnailer = Thumbnailer(input_path) - m_width = thumbnailer.width - m_height = thumbnailer.height - - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels - ) - return - - if t_method == "crop": - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - elif t_method == "scale": - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - else: - t_len = None - - return t_len - - @defer.inlineCallbacks - def _generate_local_exact_thumbnail(self, media_id, t_width, t_height, - t_method, t_type): - input_path = self.filepaths.local_media_filepath(media_id) - - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - - t_len = yield preserve_context_over_fn( - threads.deferToThread, - self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) - - if t_len: - yield self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) - - defer.returnValue(t_path) - - @defer.inlineCallbacks - def _generate_remote_exact_thumbnail(self, server_name, file_id, media_id, - t_width, t_height, t_method, t_type): - input_path = self.filepaths.remote_media_filepath(server_name, file_id) - - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - - t_len = yield preserve_context_over_fn( - threads.deferToThread, - self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) - - if t_len: - yield self.store.store_remote_media_thumbnail( - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ) - - defer.returnValue(t_path) - - @defer.inlineCallbacks - def _generate_local_thumbnails(self, media_id, media_info): - media_type = media_info["media_type"] - requirements = self._get_thumbnail_requirements(media_type) - if not requirements: - return - - input_path = self.filepaths.local_media_filepath(media_id) - thumbnailer = Thumbnailer(input_path) - m_width = thumbnailer.width - m_height = thumbnailer.height - - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels - ) - return - - local_thumbnails = [] - - def generate_thumbnails(): - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len - )) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len - )) - - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for l in local_thumbnails: - yield self.store.store_local_thumbnail(*l) - - defer.returnValue({ - "width": m_width, - "height": m_height, - }) - - @defer.inlineCallbacks - def _generate_remote_thumbnails(self, server_name, media_id, media_info): - media_type = media_info["media_type"] - file_id = media_info["filesystem_id"] - requirements = self._get_thumbnail_requirements(media_type) - if not requirements: - return - - remote_thumbnails = [] - - input_path = self.filepaths.remote_media_filepath(server_name, file_id) - thumbnailer = Thumbnailer(input_path) - m_width = thumbnailer.width - m_height = thumbnailer.height - - def generate_thumbnails(): - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels - ) - return - - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ]) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ]) - - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for r in remote_thumbnails: - yield self.store.store_remote_media_thumbnail(*r) - - defer.returnValue({ - "width": m_width, - "height": m_height, - }) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 1aad6b3551..510884262c 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base_resource import BaseMediaResource, parse_media_id +from ._base import parse_media_id, respond_with_file, respond_404 +from twisted.web.resource import Resource from synapse.http.server import request_handler from twisted.web.server import NOT_DONE_YET @@ -24,7 +25,18 @@ import logging logger = logging.getLogger(__name__) -class DownloadResource(BaseMediaResource): +class DownloadResource(Resource): + isLeaf = True + + def __init__(self, hs, media_repo): + Resource.__init__(self) + + self.filepaths = media_repo.filepaths + self.media_repo = media_repo + self.server_name = hs.hostname + self.store = hs.get_datastore() + self.version_string = hs.version_string + def render_GET(self, request): self._async_render_GET(request) return NOT_DONE_YET @@ -44,7 +56,7 @@ class DownloadResource(BaseMediaResource): def _respond_local_file(self, request, media_id, name): media_info = yield self.store.get_local_media(media_id) if not media_info: - self._respond_404(request) + respond_404(request) return media_type = media_info["media_type"] @@ -52,14 +64,14 @@ class DownloadResource(BaseMediaResource): upload_name = name if name else media_info["upload_name"] file_path = self.filepaths.local_media_filepath(media_id) - yield self._respond_with_file( + yield respond_with_file( request, media_type, file_path, media_length, upload_name=upload_name, ) @defer.inlineCallbacks def _respond_remote_file(self, request, server_name, media_id, name): - media_info = yield self._get_remote_media(server_name, media_id) + media_info = yield self.media_repo.get_remote_media(server_name, media_id) media_type = media_info["media_type"] media_length = media_info["media_length"] @@ -70,7 +82,7 @@ class DownloadResource(BaseMediaResource): server_name, filesystem_id ) - yield self._respond_with_file( + yield respond_with_file( request, media_type, file_path, media_length, upload_name=upload_name, ) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 77fb0313c5..d96bf9afe2 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -22,11 +22,395 @@ from .filepath import MediaFilePaths from twisted.web.resource import Resource +from .thumbnailer import Thumbnailer + +from synapse.http.matrixfederationclient import MatrixFederationHttpClient +from synapse.util.stringutils import random_string + +from twisted.internet import defer, threads + +from synapse.util.async import ObservableDeferred +from synapse.util.stringutils import is_ascii +from synapse.util.logcontext import preserve_context_over_fn + +import os + +import cgi import logging +import urlparse logger = logging.getLogger(__name__) +class MediaRepository(object): + def __init__(self, hs, filepaths): + self.auth = hs.get_auth() + self.client = MatrixFederationHttpClient(hs) + self.clock = hs.get_clock() + self.server_name = hs.hostname + self.store = hs.get_datastore() + self.max_upload_size = hs.config.max_upload_size + self.max_image_pixels = hs.config.max_image_pixels + self.filepaths = filepaths + self.downloads = {} + self.dynamic_thumbnails = hs.config.dynamic_thumbnails + self.thumbnail_requirements = hs.config.thumbnail_requirements + + @staticmethod + def _makedirs(filepath): + dirname = os.path.dirname(filepath) + if not os.path.exists(dirname): + os.makedirs(dirname) + + @defer.inlineCallbacks + def create_content(self, media_type, upload_name, content, content_length, + auth_user): + media_id = random_string(24) + + fname = self.filepaths.local_media_filepath(media_id) + self._makedirs(fname) + + # This shouldn't block for very long because the content will have + # already been uploaded at this point. + with open(fname, "wb") as f: + f.write(content) + + yield self.store.store_local_media( + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + ) + media_info = { + "media_type": media_type, + "media_length": content_length, + } + + yield self._generate_local_thumbnails(media_id, media_info) + + defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) + + def get_remote_media(self, server_name, media_id): + key = (server_name, media_id) + download = self.downloads.get(key) + if download is None: + download = self._get_remote_media_impl(server_name, media_id) + download = ObservableDeferred( + download, + consumeErrors=True + ) + self.downloads[key] = download + + @download.addBoth + def callback(media_info): + del self.downloads[key] + return media_info + return download.observe() + + @defer.inlineCallbacks + def _get_remote_media_impl(self, server_name, media_id): + media_info = yield self.store.get_cached_remote_media( + server_name, media_id + ) + if not media_info: + media_info = yield self._download_remote_file( + server_name, media_id + ) + defer.returnValue(media_info) + + @defer.inlineCallbacks + def _download_remote_file(self, server_name, media_id): + file_id = random_string(24) + + fname = self.filepaths.remote_media_filepath( + server_name, file_id + ) + self._makedirs(fname) + + try: + with open(fname, "wb") as f: + request_path = "/".join(( + "/_matrix/media/v1/download", server_name, media_id, + )) + length, headers = yield self.client.get_file( + server_name, request_path, output_stream=f, + max_size=self.max_upload_size, + ) + media_type = headers["Content-Type"][0] + time_now_ms = self.clock.time_msec() + + content_disposition = headers.get("Content-Disposition", None) + if content_disposition: + _, params = cgi.parse_header(content_disposition[0],) + upload_name = None + + # First check if there is a valid UTF-8 filename + upload_name_utf8 = params.get("filename*", None) + if upload_name_utf8: + if upload_name_utf8.lower().startswith("utf-8''"): + upload_name = upload_name_utf8[7:] + + # If there isn't check for an ascii name. + if not upload_name: + upload_name_ascii = params.get("filename", None) + if upload_name_ascii and is_ascii(upload_name_ascii): + upload_name = upload_name_ascii + + if upload_name: + upload_name = urlparse.unquote(upload_name) + try: + upload_name = upload_name.decode("utf-8") + except UnicodeDecodeError: + upload_name = None + else: + upload_name = None + + yield self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) + except: + os.remove(fname) + raise + + media_info = { + "media_type": media_type, + "media_length": length, + "upload_name": upload_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + } + + yield self._generate_remote_thumbnails( + server_name, media_id, media_info + ) + + defer.returnValue(media_info) + + def _get_thumbnail_requirements(self, media_type): + return self.thumbnail_requirements.get(media_type, ()) + + def _generate_thumbnail(self, input_path, t_path, t_width, t_height, + t_method, t_type): + thumbnailer = Thumbnailer(input_path) + m_width = thumbnailer.width + m_height = thumbnailer.height + + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, m_height, self.max_image_pixels + ) + return + + if t_method == "crop": + t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) + elif t_method == "scale": + t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + else: + t_len = None + + return t_len + + @defer.inlineCallbacks + def generate_local_exact_thumbnail(self, media_id, t_width, t_height, + t_method, t_type): + input_path = self.filepaths.local_media_filepath(media_id) + + t_path = self.filepaths.local_media_thumbnail( + media_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + + t_len = yield preserve_context_over_fn( + threads.deferToThread, + self._generate_thumbnail, + input_path, t_path, t_width, t_height, t_method, t_type + ) + + if t_len: + yield self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) + + defer.returnValue(t_path) + + @defer.inlineCallbacks + def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, + t_width, t_height, t_method, t_type): + input_path = self.filepaths.remote_media_filepath(server_name, file_id) + + t_path = self.filepaths.remote_media_thumbnail( + server_name, file_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + + t_len = yield preserve_context_over_fn( + threads.deferToThread, + self._generate_thumbnail, + input_path, t_path, t_width, t_height, t_method, t_type + ) + + if t_len: + yield self.store.store_remote_media_thumbnail( + server_name, media_id, file_id, + t_width, t_height, t_type, t_method, t_len + ) + + defer.returnValue(t_path) + + @defer.inlineCallbacks + def _generate_local_thumbnails(self, media_id, media_info): + media_type = media_info["media_type"] + requirements = self._get_thumbnail_requirements(media_type) + if not requirements: + return + + input_path = self.filepaths.local_media_filepath(media_id) + thumbnailer = Thumbnailer(input_path) + m_width = thumbnailer.width + m_height = thumbnailer.height + + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, m_height, self.max_image_pixels + ) + return + + local_thumbnails = [] + + def generate_thumbnails(): + scales = set() + crops = set() + for r_width, r_height, r_method, r_type in requirements: + if r_method == "scale": + t_width, t_height = thumbnailer.aspect(r_width, r_height) + scales.add(( + min(m_width, t_width), min(m_height, t_height), r_type, + )) + elif r_method == "crop": + crops.add((r_width, r_height, r_type)) + + for t_width, t_height, t_type in scales: + t_method = "scale" + t_path = self.filepaths.local_media_thumbnail( + media_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + + local_thumbnails.append(( + media_id, t_width, t_height, t_type, t_method, t_len + )) + + for t_width, t_height, t_type in crops: + if (t_width, t_height, t_type) in scales: + # If the aspect ratio of the cropped thumbnail matches a purely + # scaled one then there is no point in calculating a separate + # thumbnail. + continue + t_method = "crop" + t_path = self.filepaths.local_media_thumbnail( + media_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) + local_thumbnails.append(( + media_id, t_width, t_height, t_type, t_method, t_len + )) + + yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) + + for l in local_thumbnails: + yield self.store.store_local_thumbnail(*l) + + defer.returnValue({ + "width": m_width, + "height": m_height, + }) + + @defer.inlineCallbacks + def _generate_remote_thumbnails(self, server_name, media_id, media_info): + media_type = media_info["media_type"] + file_id = media_info["filesystem_id"] + requirements = self._get_thumbnail_requirements(media_type) + if not requirements: + return + + remote_thumbnails = [] + + input_path = self.filepaths.remote_media_filepath(server_name, file_id) + thumbnailer = Thumbnailer(input_path) + m_width = thumbnailer.width + m_height = thumbnailer.height + + def generate_thumbnails(): + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, m_height, self.max_image_pixels + ) + return + + scales = set() + crops = set() + for r_width, r_height, r_method, r_type in requirements: + if r_method == "scale": + t_width, t_height = thumbnailer.aspect(r_width, r_height) + scales.add(( + min(m_width, t_width), min(m_height, t_height), r_type, + )) + elif r_method == "crop": + crops.add((r_width, r_height, r_type)) + + for t_width, t_height, t_type in scales: + t_method = "scale" + t_path = self.filepaths.remote_media_thumbnail( + server_name, file_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + remote_thumbnails.append([ + server_name, media_id, file_id, + t_width, t_height, t_type, t_method, t_len + ]) + + for t_width, t_height, t_type in crops: + if (t_width, t_height, t_type) in scales: + # If the aspect ratio of the cropped thumbnail matches a purely + # scaled one then there is no point in calculating a separate + # thumbnail. + continue + t_method = "crop" + t_path = self.filepaths.remote_media_thumbnail( + server_name, file_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) + remote_thumbnails.append([ + server_name, media_id, file_id, + t_width, t_height, t_type, t_method, t_len + ]) + + yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) + + for r in remote_thumbnails: + yield self.store.store_remote_media_thumbnail(*r) + + defer.returnValue({ + "width": m_width, + "height": m_height, + }) + + class MediaRepositoryResource(Resource): """File uploading and downloading. @@ -75,9 +459,12 @@ class MediaRepositoryResource(Resource): def __init__(self, hs): Resource.__init__(self) filepaths = MediaFilePaths(hs.config.media_store_path) - self.putChild("upload", UploadResource(hs, filepaths)) - self.putChild("download", DownloadResource(hs, filepaths)) - self.putChild("thumbnail", ThumbnailResource(hs, filepaths)) + + media_repo = MediaRepository(hs, filepaths) + + self.putChild("upload", UploadResource(hs, media_repo)) + self.putChild("download", DownloadResource(hs, media_repo)) + self.putChild("thumbnail", ThumbnailResource(hs, media_repo)) self.putChild("identicon", IdenticonResource()) if hs.config.url_preview_enabled: - self.putChild("preview_url", PreviewUrlResource(hs, filepaths)) + self.putChild("preview_url", PreviewUrlResource(hs, media_repo)) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index c27ba72735..69327ac493 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -13,10 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base_resource import BaseMediaResource - from twisted.web.server import NOT_DONE_YET from twisted.internet import defer +from twisted.web.resource import Resource from synapse.api.errors import ( SynapseError, Codes, @@ -41,12 +40,22 @@ import logging logger = logging.getLogger(__name__) -class PreviewUrlResource(BaseMediaResource): +class PreviewUrlResource(Resource): isLeaf = True - def __init__(self, hs, filepaths): - BaseMediaResource.__init__(self, hs, filepaths) + def __init__(self, hs, media_repo): + Resource.__init__(self) + + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.version_string = hs.version_string + self.filepaths = media_repo.filepaths + self.max_spider_size = hs.config.max_spider_size + self.server_name = hs.hostname + self.store = hs.get_datastore() self.client = SpiderHttpClient(hs) + self.media_repo = media_repo + if hasattr(hs.config, "url_preview_url_blacklist"): self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist @@ -156,7 +165,7 @@ class PreviewUrlResource(BaseMediaResource): logger.debug("got media_info of '%s'" % media_info) if self._is_media(media_info['media_type']): - dims = yield self._generate_local_thumbnails( + dims = yield self.media_repo._generate_local_thumbnails( media_info['filesystem_id'], media_info ) @@ -179,23 +188,27 @@ class PreviewUrlResource(BaseMediaResource): elif self._is_html(media_info['media_type']): # TODO: somehow stop a big HTML tree from exploding synapse's RAM - from lxml import html + from lxml import etree + + file = open(media_info['filename']) + body = file.read() + file.close() + + # clobber the encoding from the content-type, or default to utf-8 + # XXX: this overrides any <meta/> or XML charset headers in the body + # which may pose problems, but so far seems to work okay. + match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I) + encoding = match.group(1) if match else "utf-8" try: - tree = html.parse(media_info['filename']) + parser = etree.HTMLParser(recover=True, encoding=encoding) + tree = etree.fromstring(body, parser) og = yield self._calc_og(tree, media_info, requester) except UnicodeDecodeError: - # XXX: evil evil bodge - # Empirically, sites like google.com mix Latin-1 and utf-8 - # encodings in the same page. The rogue Latin-1 characters - # cause lxml to choke with a UnicodeDecodeError, so if we - # see this we go and do a manual decode of the HTML before - # handing it to lxml as utf-8 encoding, counter-intuitively, - # which seems to make it happier... - file = open(media_info['filename']) - body = file.read() - file.close() - tree = html.fromstring(body.decode('utf-8', 'ignore')) + # blindly try decoding the body as utf-8, which seems to fix + # the charset mismatches on https://google.com + parser = etree.HTMLParser(recover=True, encoding=encoding) + tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser) og = yield self._calc_og(tree, media_info, requester) else: @@ -287,7 +300,7 @@ class PreviewUrlResource(BaseMediaResource): if self._is_media(image_info['media_type']): # TODO: make sure we don't choke on white-on-transparent images - dims = yield self._generate_local_thumbnails( + dims = yield self.media_repo._generate_local_thumbnails( image_info['filesystem_id'], image_info ) if dims: @@ -358,7 +371,7 @@ class PreviewUrlResource(BaseMediaResource): file_id = random_string(24) fname = self.filepaths.local_media_filepath(file_id) - self._makedirs(fname) + self.media_repo._makedirs(fname) try: with open(fname, "wb") as f: diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 40ef22459c..234dd4261c 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -14,7 +14,8 @@ # limitations under the License. -from .base_resource import BaseMediaResource, parse_media_id +from ._base import parse_media_id, respond_404, respond_with_file +from twisted.web.resource import Resource from synapse.http.servlet import parse_string, parse_integer from synapse.http.server import request_handler @@ -26,9 +27,19 @@ import logging logger = logging.getLogger(__name__) -class ThumbnailResource(BaseMediaResource): +class ThumbnailResource(Resource): isLeaf = True + def __init__(self, hs, media_repo): + Resource.__init__(self) + + self.store = hs.get_datastore() + self.filepaths = media_repo.filepaths + self.media_repo = media_repo + self.dynamic_thumbnails = hs.config.dynamic_thumbnails + self.server_name = hs.hostname + self.version_string = hs.version_string + def render_GET(self, request): self._async_render_GET(request) return NOT_DONE_YET @@ -69,12 +80,12 @@ class ThumbnailResource(BaseMediaResource): media_info = yield self.store.get_local_media(media_id) if not media_info: - self._respond_404(request) + respond_404(request) return # if media_info["media_type"] == "image/svg+xml": # file_path = self.filepaths.local_media_filepath(media_id) - # yield self._respond_with_file(request, media_info["media_type"], file_path) + # yield respond_with_file(request, media_info["media_type"], file_path) # return thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) @@ -91,7 +102,7 @@ class ThumbnailResource(BaseMediaResource): file_path = self.filepaths.local_media_thumbnail( media_id, t_width, t_height, t_type, t_method, ) - yield self._respond_with_file(request, t_type, file_path) + yield respond_with_file(request, t_type, file_path) else: yield self._respond_default_thumbnail( @@ -105,12 +116,12 @@ class ThumbnailResource(BaseMediaResource): media_info = yield self.store.get_local_media(media_id) if not media_info: - self._respond_404(request) + respond_404(request) return # if media_info["media_type"] == "image/svg+xml": # file_path = self.filepaths.local_media_filepath(media_id) - # yield self._respond_with_file(request, media_info["media_type"], file_path) + # yield respond_with_file(request, media_info["media_type"], file_path) # return thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) @@ -124,18 +135,18 @@ class ThumbnailResource(BaseMediaResource): file_path = self.filepaths.local_media_thumbnail( media_id, desired_width, desired_height, desired_type, desired_method, ) - yield self._respond_with_file(request, desired_type, file_path) + yield respond_with_file(request, desired_type, file_path) return logger.debug("We don't have a local thumbnail of that size. Generating") # Okay, so we generate one. - file_path = yield self._generate_local_exact_thumbnail( + file_path = yield self.media_repo.generate_local_exact_thumbnail( media_id, desired_width, desired_height, desired_method, desired_type ) if file_path: - yield self._respond_with_file(request, desired_type, file_path) + yield respond_with_file(request, desired_type, file_path) else: yield self._respond_default_thumbnail( request, media_info, desired_width, desired_height, @@ -146,11 +157,11 @@ class ThumbnailResource(BaseMediaResource): def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, desired_width, desired_height, desired_method, desired_type): - media_info = yield self._get_remote_media(server_name, media_id) + media_info = yield self.media_repo.get_remote_media(server_name, media_id) # if media_info["media_type"] == "image/svg+xml": # file_path = self.filepaths.remote_media_filepath(server_name, media_id) - # yield self._respond_with_file(request, media_info["media_type"], file_path) + # yield respond_with_file(request, media_info["media_type"], file_path) # return thumbnail_infos = yield self.store.get_remote_media_thumbnails( @@ -170,19 +181,19 @@ class ThumbnailResource(BaseMediaResource): server_name, file_id, desired_width, desired_height, desired_type, desired_method, ) - yield self._respond_with_file(request, desired_type, file_path) + yield respond_with_file(request, desired_type, file_path) return logger.debug("We don't have a local thumbnail of that size. Generating") # Okay, so we generate one. - file_path = yield self._generate_remote_exact_thumbnail( + file_path = yield self.media_repo.generate_remote_exact_thumbnail( server_name, file_id, media_id, desired_width, desired_height, desired_method, desired_type ) if file_path: - yield self._respond_with_file(request, desired_type, file_path) + yield respond_with_file(request, desired_type, file_path) else: yield self._respond_default_thumbnail( request, media_info, desired_width, desired_height, @@ -194,11 +205,11 @@ class ThumbnailResource(BaseMediaResource): height, method, m_type): # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead. - media_info = yield self._get_remote_media(server_name, media_id) + media_info = yield self.media_repo.get_remote_media(server_name, media_id) # if media_info["media_type"] == "image/svg+xml": # file_path = self.filepaths.remote_media_filepath(server_name, media_id) - # yield self._respond_with_file(request, media_info["media_type"], file_path) + # yield respond_with_file(request, media_info["media_type"], file_path) # return thumbnail_infos = yield self.store.get_remote_media_thumbnails( @@ -219,7 +230,7 @@ class ThumbnailResource(BaseMediaResource): file_path = self.filepaths.remote_media_thumbnail( server_name, file_id, t_width, t_height, t_type, t_method, ) - yield self._respond_with_file(request, t_type, file_path, t_length) + yield respond_with_file(request, t_type, file_path, t_length) else: yield self._respond_default_thumbnail( request, media_info, width, height, method, m_type, @@ -245,7 +256,7 @@ class ThumbnailResource(BaseMediaResource): "_default", "_default", ) if not thumbnail_infos: - self._respond_404(request) + respond_404(request) return thumbnail_info = self._select_thumbnail( @@ -261,7 +272,7 @@ class ThumbnailResource(BaseMediaResource): file_path = self.filepaths.default_thumbnail( top_level_type, sub_type, t_width, t_height, t_type, t_method, ) - yield self.respond_with_file(request, t_type, file_path, t_length) + yield respond_with_file(request, t_type, file_path, t_length) def _select_thumbnail(self, desired_width, desired_height, desired_method, desired_type, thumbnail_infos): diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 9c7ad4ae85..299e1f6e56 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -15,20 +15,33 @@ from synapse.http.server import respond_with_json, request_handler -from synapse.util.stringutils import random_string from synapse.api.errors import SynapseError from twisted.web.server import NOT_DONE_YET from twisted.internet import defer -from .base_resource import BaseMediaResource +from twisted.web.resource import Resource import logging logger = logging.getLogger(__name__) -class UploadResource(BaseMediaResource): +class UploadResource(Resource): + isLeaf = True + + def __init__(self, hs, media_repo): + Resource.__init__(self) + + self.media_repo = media_repo + self.filepaths = media_repo.filepaths + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.server_name = hs.hostname + self.auth = hs.get_auth() + self.max_upload_size = hs.config.max_upload_size + self.version_string = hs.version_string + def render_POST(self, request): self._async_render_POST(request) return NOT_DONE_YET @@ -37,36 +50,6 @@ class UploadResource(BaseMediaResource): respond_with_json(request, 200, {}, send_cors=True) return NOT_DONE_YET - @defer.inlineCallbacks - def create_content(self, media_type, upload_name, content, content_length, - auth_user): - media_id = random_string(24) - - fname = self.filepaths.local_media_filepath(media_id) - self._makedirs(fname) - - # This shouldn't block for very long because the content will have - # already been uploaded at this point. - with open(fname, "wb") as f: - f.write(content) - - yield self.store.store_local_media( - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=content_length, - user_id=auth_user, - ) - media_info = { - "media_type": media_type, - "media_length": content_length, - } - - yield self._generate_local_thumbnails(media_id, media_info) - - defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) - @request_handler @defer.inlineCallbacks def _async_render_POST(self, request): @@ -108,7 +91,7 @@ class UploadResource(BaseMediaResource): # disposition = headers.getRawHeaders("Content-Disposition")[0] # TODO(markjh): parse content-dispostion - content_uri = yield self.create_content( + content_uri = yield self.media_repo.create_content( media_type, upload_name, request.content.read(), content_length, requester.user ) diff --git a/synapse/state.py b/synapse/state.py index 58211f5feb..d0f76dc4f5 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -214,7 +214,7 @@ class StateHandler(object): if self._state_cache is not None: cache = self._state_cache.get(group_names, None) - if cache and cache.state_group: + if cache: cache.ts = self.clock.time_msec() event_dict = yield self.store.get_events(cache.state.values()) @@ -230,22 +230,34 @@ class StateHandler(object): (cache.state_group, state, prev_states) ) + logger.info("Resolving state for %s with %d groups", room_id, len(state_groups)) + new_state, prev_states = self._resolve_events( state_groups.values(), event_type, state_key ) + state_group = None + new_state_event_ids = frozenset(e.event_id for e in new_state.values()) + for sg, events in state_groups.items(): + if new_state_event_ids == frozenset(e.event_id for e in events): + state_group = sg + break + if self._state_cache is not None: cache = _StateCacheEntry( state={key: event.event_id for key, event in new_state.items()}, - state_group=None, + state_group=state_group, ts=self.clock.time_msec() ) self._state_cache[group_names] = cache - defer.returnValue((None, new_state, prev_states)) + defer.returnValue((state_group, new_state, prev_states)) def resolve_events(self, state_sets, event): + logger.info( + "Resolving state for %s with %d groups", event.room_id, len(state_sets) + ) if event.is_state(): return self._resolve_events( state_sets, event.type, event.state_key diff --git a/synapse/storage/state.py b/synapse/storage/state.py index c5d2a3a6df..5b743db67a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -174,6 +174,12 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) + @cached(num_args=2, lru=True, max_entries=1000) + def _get_state_group_from_group(self, group, types): + raise NotImplementedError() + + @cachedList(cached_method_name="_get_state_group_from_group", + list_name="groups", num_args=2, inlineCallbacks=True) def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ @@ -201,18 +207,23 @@ class StateStore(SQLBaseStore): txn.execute(sql, args) rows = self.cursor_to_dict(txn) - results = {} + results = {group: {} for group in groups} for row in rows: key = (row["type"], row["state_key"]) - results.setdefault(row["state_group"], {})[key] = row["event_id"] + results[row["state_group"]][key] = row["event_id"] return results + results = {} + chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] for chunk in chunks: - return self.runInteraction( + res = yield self.runInteraction( "_get_state_groups_from_groups", f, chunk ) + results.update(res) + + defer.returnValue(results) @defer.inlineCallbacks def get_state_for_events(self, event_ids, types): @@ -359,6 +370,8 @@ class StateStore(SQLBaseStore): a `state_key` of None matches all state_keys. If `types` is None then all events are returned. """ + if types: + types = frozenset(types) results = {} missing_groups = [] if types is not None: diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index b462495eb8..2b3f0bef3c 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.errors import SynapseError from synapse.util.logcontext import PreserveLoggingContext from twisted.internet import defer, reactor, task @@ -80,7 +81,7 @@ class Clock(object): def timed_out_fn(): try: - ret_deferred.errback(RuntimeError("Timed out")) + ret_deferred.errback(SynapseError(504, "Timed out")) except: pass diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index c51b641125..e1f374807e 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -50,7 +50,7 @@ block_db_txn_duration = metrics.register_distribution( class Measure(object): __slots__ = [ "clock", "name", "start_context", "start", "new_context", "ru_utime", - "ru_stime", "db_txn_count", "db_txn_duration" + "ru_stime", "db_txn_count", "db_txn_duration", "created_context" ] def __init__(self, clock, name): @@ -58,14 +58,20 @@ class Measure(object): self.name = name self.start_context = None self.start = None + self.created_context = False def __enter__(self): self.start = self.clock.time_msec() self.start_context = LoggingContext.current_context() - if self.start_context: - self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() - self.db_txn_count = self.start_context.db_txn_count - self.db_txn_duration = self.start_context.db_txn_duration + if not self.start_context: + logger.warn("Entered Measure without log context: %s", self.name) + self.start_context = LoggingContext("Measure") + self.start_context.__enter__() + self.created_context = True + + self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() + self.db_txn_count = self.start_context.db_txn_count + self.db_txn_duration = self.start_context.db_txn_duration def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None or not self.start_context: @@ -91,7 +97,12 @@ class Measure(object): block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name) block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name) - block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name) + block_db_txn_count.inc_by( + context.db_txn_count - self.db_txn_count, self.name + ) block_db_txn_duration.inc_by( context.db_txn_duration - self.db_txn_duration, self.name ) + + if self.created_context: + self.start_context.__exit__(exc_type, exc_val, exc_tb) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 983caafe8a..1f13cd0bc0 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -15,8 +15,6 @@ from twisted.internet import defer from tests import unittest -from synapse.replication.slave.storage.events import SlavedEventStore - from mock import Mock, NonCallableMock from tests.utils import setup_test_homeserver from synapse.replication.resource import ReplicationResource @@ -38,7 +36,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): self.replication = ReplicationResource(self.hs) self.master_store = self.hs.get_datastore() - self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs) + self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 @defer.inlineCallbacks diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index baa4a26eb5..41a626cf70 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -16,6 +16,7 @@ from ._base import BaseSlavedStoreTestCase from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext +from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser from twisted.internet import defer @@ -43,6 +44,8 @@ def patch__eq__(cls): class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): + STORE_TYPE = SlavedEventStore + def setUp(self): # Patch up the equality operator for events so that we can check # whether lists of events match using assertEquals @@ -251,6 +254,18 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) yield self.check("get_event", [msg.event_id], redacted) + @defer.inlineCallbacks + def test_invites(self): + yield self.check("get_invited_rooms_for_user", [USER_ID_2], []) + event = yield self.persist( + type="m.room.member", key=USER_ID_2, membership="invite" + ) + yield self.replicate() + yield self.check("get_invited_rooms_for_user", [USER_ID_2], [RoomsForUser( + ROOM_ID, USER_ID, "invite", event.event_id, + event.internal_metadata.stream_ordering + )]) + event_id = 0 @defer.inlineCallbacks diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py new file mode 100644 index 0000000000..6624fe4eea --- /dev/null +++ b/tests/replication/slave/storage/test_receipts.py @@ -0,0 +1,39 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStoreTestCase + +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore + +from twisted.internet import defer + +USER_ID = "@feeling:blue" +ROOM_ID = "!room:blue" +EVENT_ID = "$event:blue" + + +class SlavedReceiptTestCase(BaseSlavedStoreTestCase): + + STORE_TYPE = SlavedReceiptsStore + + @defer.inlineCallbacks + def test_receipt(self): + yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) + yield self.master_store.insert_receipt( + ROOM_ID, "m.read", USER_ID, [EVENT_ID], {} + ) + yield self.replicate() + yield self.check("get_receipts_for_user", [USER_ID, "m.read"], { + ROOM_ID: EVENT_ID + }) diff --git a/tests/test_state.py b/tests/test_state.py index a1ea7ef672..1a11bbcee0 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -140,13 +140,13 @@ class StateTestCase(unittest.TestCase): "add_event_hashes", ] ) - hs = Mock(spec=[ + hs = Mock(spec_set=[ "get_datastore", "get_auth", "get_state_handler", "get_clock", ]) hs.get_datastore.return_value = self.store hs.get_state_handler.return_value = None - hs.get_auth.return_value = Auth(hs) hs.get_clock.return_value = MockClock() + hs.get_auth.return_value = Auth(hs) self.state = StateHandler(hs) self.event_id = 0 |