diff options
Diffstat (limited to 'synapse')
25 files changed, 1331 insertions, 498 deletions
diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 4748f71c2f..29eb012ddb 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -96,7 +96,7 @@ class TlsConfig(Config): # certificates returned by this server match one of the fingerprints. # # Synapse automatically adds the fingerprint of its own certificate - # to the list. So if federation traffic is handle directly by synapse + # to the list. So if federation traffic is handled directly by synapse # then no modification to the list is required. # # If synapse is run behind a load balancer that handles the TLS then it diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 3e7809b04f..9d39f46583 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus") sent_transactions_counter = client_metrics.register_counter("sent_transactions") +events_processed_counter = client_metrics.register_counter("events_processed") + class TransactionQueue(object): """This class makes sure we only have one transaction in flight at @@ -205,6 +207,8 @@ class TransactionQueue(object): self._send_pdu(event, destinations) + events_processed_counter.inc_by(len(events)) + yield self.store.update_federation_out_pos( "events", next_token ) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index feca3e4c10..3dd3fa2a27 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -15,6 +15,7 @@ from twisted.internet import defer +import synapse from synapse.api.constants import EventTypes from synapse.util.metrics import Measure from synapse.util.logcontext import make_deferred_yieldable, preserve_fn @@ -23,6 +24,10 @@ import logging logger = logging.getLogger(__name__) +metrics = synapse.metrics.get_metrics_for(__name__) + +events_processed_counter = metrics.register_counter("events_processed") + def log_failure(failure): logger.error( @@ -103,6 +108,8 @@ class ApplicationServicesHandler(object): service, event ) + events_processed_counter.inc_by(len(events)) + yield self.store.set_appservice_last_pos(upper_bound) finally: self.is_processing = False diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index f7fad15c62..d996aa90bb 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -17,7 +17,8 @@ import logging from twisted.internet import defer -from synapse.types import get_domain_from_id +from synapse.api.errors import SynapseError +from synapse.types import get_domain_from_id, UserID from synapse.util.stringutils import random_string @@ -33,7 +34,7 @@ class DeviceMessageHandler(object): """ self.store = hs.get_datastore() self.notifier = hs.get_notifier() - self.is_mine_id = hs.is_mine_id + self.is_mine = hs.is_mine self.federation = hs.get_federation_sender() hs.get_replication_layer().register_edu_handler( @@ -52,6 +53,12 @@ class DeviceMessageHandler(object): message_type = content["type"] message_id = content["message_id"] for user_id, by_device in content["messages"].items(): + # we use UserID.from_string to catch invalid user ids + if not self.is_mine(UserID.from_string(user_id)): + logger.warning("Request for keys for non-local user %s", + user_id) + raise SynapseError(400, "Not a user here") + messages_by_device = { device_id: { "content": message_content, @@ -77,7 +84,8 @@ class DeviceMessageHandler(object): local_messages = {} remote_messages = {} for user_id, by_device in messages.items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): messages_by_device = { device_id: { "content": message_content, diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 668a90e495..5af8abf66b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -20,7 +20,7 @@ from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.api.errors import SynapseError, CodeMessageException -from synapse.types import get_domain_from_id +from synapse.types import get_domain_from_id, UserID from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination @@ -32,7 +32,7 @@ class E2eKeysHandler(object): self.store = hs.get_datastore() self.federation = hs.get_replication_layer() self.device_handler = hs.get_device_handler() - self.is_mine_id = hs.is_mine_id + self.is_mine = hs.is_mine self.clock = hs.get_clock() # doesn't really work as part of the generic query API, because the @@ -70,7 +70,8 @@ class E2eKeysHandler(object): remote_queries = {} for user_id, device_ids in device_keys_query.items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): local_query[user_id] = device_ids else: remote_queries[user_id] = device_ids @@ -170,7 +171,8 @@ class E2eKeysHandler(object): result_dict = {} for user_id, device_ids in query.items(): - if not self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if not self.is_mine(UserID.from_string(user_id)): logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") @@ -213,7 +215,8 @@ class E2eKeysHandler(object): remote_queries = {} for user_id, device_keys in query.get("one_time_keys", {}).items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): for device_id, algorithm in device_keys.items(): local_query.append((user_id, device_id, algorithm)) else: diff --git a/synapse/http/server.py b/synapse/http/server.py index 6e8f4c9c5f..165c684d0d 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -42,36 +42,70 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) -incoming_requests_counter = metrics.register_counter( - "requests", +# total number of responses served, split by method/servlet/tag +response_count = metrics.register_counter( + "response_count", labels=["method", "servlet", "tag"], + alternative_names=( + # the following are all deprecated aliases for the same metric + metrics.name_prefix + x for x in ( + "_requests", + "_response_time:count", + "_response_ru_utime:count", + "_response_ru_stime:count", + "_response_db_txn_count:count", + "_response_db_txn_duration:count", + ) + ) ) + outgoing_responses_counter = metrics.register_counter( "responses", labels=["method", "code"], ) -response_timer = metrics.register_distribution( - "response_time", - labels=["method", "servlet", "tag"] +response_timer = metrics.register_counter( + "response_time_seconds", + labels=["method", "servlet", "tag"], + alternative_names=( + metrics.name_prefix + "_response_time:total", + ), ) -response_ru_utime = metrics.register_distribution( - "response_ru_utime", labels=["method", "servlet", "tag"] +response_ru_utime = metrics.register_counter( + "response_ru_utime_seconds", labels=["method", "servlet", "tag"], + alternative_names=( + metrics.name_prefix + "_response_ru_utime:total", + ), ) -response_ru_stime = metrics.register_distribution( - "response_ru_stime", labels=["method", "servlet", "tag"] +response_ru_stime = metrics.register_counter( + "response_ru_stime_seconds", labels=["method", "servlet", "tag"], + alternative_names=( + metrics.name_prefix + "_response_ru_stime:total", + ), ) -response_db_txn_count = metrics.register_distribution( - "response_db_txn_count", labels=["method", "servlet", "tag"] +response_db_txn_count = metrics.register_counter( + "response_db_txn_count", labels=["method", "servlet", "tag"], + alternative_names=( + metrics.name_prefix + "_response_db_txn_count:total", + ), ) -response_db_txn_duration = metrics.register_distribution( - "response_db_txn_duration", labels=["method", "servlet", "tag"] +# seconds spent waiting for db txns, excluding scheduling time, when processing +# this request +response_db_txn_duration = metrics.register_counter( + "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"], + alternative_names=( + metrics.name_prefix + "_response_db_txn_duration:total", + ), ) +# seconds spent waiting for a db connection, when processing this request +response_db_sched_duration = metrics.register_counter( + "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"] +) _next_request_id = 0 @@ -288,15 +322,6 @@ class JsonResource(HttpServer, resource.Resource): def _send_response(self, request, code, response_json_object, response_code_message=None): - # could alternatively use request.notifyFinish() and flip a flag when - # the Deferred fires, but since the flag is RIGHT THERE it seems like - # a waste. - if request._disconnected: - logger.warn( - "Not sending response to request %s, already disconnected.", - request) - return - outgoing_responses_counter.inc(request.method, str(code)) # TODO: Only enable CORS for the requests that need it. @@ -330,7 +355,7 @@ class RequestMetrics(object): ) return - incoming_requests_counter.inc(request.method, self.name, tag) + response_count.inc(request.method, self.name, tag) response_timer.inc_by( clock.time_msec() - self.start, request.method, @@ -349,7 +374,10 @@ class RequestMetrics(object): context.db_txn_count, request.method, self.name, tag ) response_db_txn_duration.inc_by( - context.db_txn_duration, request.method, self.name, tag + context.db_txn_duration_ms / 1000., request.method, self.name, tag + ) + response_db_sched_duration.inc_by( + context.db_sched_duration_ms / 1000., request.method, self.name, tag ) @@ -372,6 +400,15 @@ class RootRedirect(resource.Resource): def respond_with_json(request, code, json_object, send_cors=False, response_code_message=None, pretty_print=False, version_string="", canonical_json=True): + # could alternatively use request.notifyFinish() and flip a flag when + # the Deferred fires, but since the flag is RIGHT THERE it seems like + # a waste. + if request._disconnected: + logger.warn( + "Not sending response to request %s, already disconnected.", + request) + return + if pretty_print: json_bytes = encode_pretty_printed_json(json_object) + "\n" else: diff --git a/synapse/http/site.py b/synapse/http/site.py index cd1492b1c3..e422c8dfae 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -66,14 +66,15 @@ class SynapseRequest(Request): context = LoggingContext.current_context() ru_utime, ru_stime = context.get_resource_usage() db_txn_count = context.db_txn_count - db_txn_duration = context.db_txn_duration + db_txn_duration_ms = context.db_txn_duration_ms + db_sched_duration_ms = context.db_sched_duration_ms except Exception: ru_utime, ru_stime = (0, 0) - db_txn_count, db_txn_duration = (0, 0) + db_txn_count, db_txn_duration_ms = (0, 0) self.site.access_logger.info( "%s - %s - {%s}" - " Processed request: %dms (%dms, %dms) (%dms/%d)" + " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)" " %sB %s \"%s %s %s\" \"%s\"", self.getClientIP(), self.site.site_tag, @@ -81,7 +82,8 @@ class SynapseRequest(Request): int(time.time() * 1000) - self.start_time, int(ru_utime * 1000), int(ru_stime * 1000), - int(db_txn_duration * 1000), + db_sched_duration_ms, + db_txn_duration_ms, int(db_txn_count), self.sentLength, self.code, diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py index e87b2b80a7..f480aae614 100644 --- a/synapse/metrics/metric.py +++ b/synapse/metrics/metric.py @@ -17,16 +17,33 @@ from itertools import chain -# TODO(paul): I can't believe Python doesn't have one of these -def map_concat(func, items): - # flatten a list-of-lists - return list(chain.from_iterable(map(func, items))) +def flatten(items): + """Flatten a list of lists + + Args: + items: iterable[iterable[X]] + + Returns: + list[X]: flattened list + """ + return list(chain.from_iterable(items)) class BaseMetric(object): + """Base class for metrics which report a single value per label set + """ - def __init__(self, name, labels=[]): - self.name = name + def __init__(self, name, labels=[], alternative_names=[]): + """ + Args: + name (str): principal name for this metric + labels (list(str)): names of the labels which will be reported + for this metric + alternative_names (iterable(str)): list of alternative names for + this metric. This can be useful to provide a migration path + when renaming metrics. + """ + self._names = [name] + list(alternative_names) self.labels = labels # OK not to clone as we never write it def dimension(self): @@ -36,7 +53,7 @@ class BaseMetric(object): return not len(self.labels) def _render_labelvalue(self, value): - # TODO: some kind of value escape + # TODO: escape backslashes, quotes and newlines return '"%s"' % (value) def _render_key(self, values): @@ -47,19 +64,60 @@ class BaseMetric(object): for k, v in zip(self.labels, values)]) ) + def _render_for_labels(self, label_values, value): + """Render this metric for a single set of labels + + Args: + label_values (list[str]): values for each of the labels + value: value of the metric at with these labels + + Returns: + iterable[str]: rendered metric + """ + rendered_labels = self._render_key(label_values) + return ( + "%s%s %.12g" % (name, rendered_labels, value) + for name in self._names + ) + + def render(self): + """Render this metric + + Each metric is rendered as: + + name{label1="val1",label2="val2"} value + + https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details + + Returns: + iterable[str]: rendered metrics + """ + raise NotImplementedError() + class CounterMetric(BaseMetric): """The simplest kind of metric; one that stores a monotonically-increasing - integer that counts events.""" + value that counts events or running totals. + + Example use cases for Counters: + - Number of requests processed + - Number of items that were inserted into a queue + - Total amount of data that a system has processed + Counters can only go up (and be reset when the process restarts). + """ def __init__(self, *args, **kwargs): super(CounterMetric, self).__init__(*args, **kwargs) + # dict[list[str]]: value for each set of label values. the keys are the + # label values, in the same order as the labels in self.labels. + # + # (if the metric is a scalar, the (single) key is the empty list). self.counts = {} # Scalar metrics are never empty if self.is_scalar(): - self.counts[()] = 0 + self.counts[()] = 0. def inc_by(self, incr, *values): if len(values) != self.dimension(): @@ -77,11 +135,11 @@ class CounterMetric(BaseMetric): def inc(self, *values): self.inc_by(1, *values) - def render_item(self, k): - return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])] - def render(self): - return map_concat(self.render_item, sorted(self.counts.keys())) + return flatten( + self._render_for_labels(k, self.counts[k]) + for k in sorted(self.counts.keys()) + ) class CallbackMetric(BaseMetric): @@ -98,10 +156,12 @@ class CallbackMetric(BaseMetric): value = self.callback() if self.is_scalar(): - return ["%s %.12g" % (self.name, value)] + return list(self._render_for_labels([], value)) - return ["%s%s %.12g" % (self.name, self._render_key(k), value[k]) - for k in sorted(value.keys())] + return flatten( + self._render_for_labels(k, value[k]) + for k in sorted(value.keys()) + ) class DistributionMetric(object): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index d59503b905..0a9a290af4 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -517,25 +517,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_error("Wrong remote") def on_RDATA(self, cmd): + stream_name = cmd.stream_name + inbound_rdata_count.inc(stream_name) + try: - row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row) + row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row) except Exception: logger.exception( "[%s] Failed to parse RDATA: %r %r", - self.id(), cmd.stream_name, cmd.row + self.id(), stream_name, cmd.row ) raise if cmd.token is None: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token - self.pending_batches.setdefault(cmd.stream_name, []).append(row) + self.pending_batches.setdefault(stream_name, []).append(row) else: # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(cmd.stream_name, []) + rows = self.pending_batches.pop(stream_name, []) rows.append(row) - self.handler.on_rdata(cmd.stream_name, cmd.token, rows) + self.handler.on_rdata(stream_name, cmd.token, rows) def on_POSITION(self, cmd): self.handler.on_position(cmd.stream_name, cmd.token) @@ -644,3 +647,9 @@ metrics.register_callback( }, labels=["command", "name", "conn_id"], ) + +# number of updates received for each RDATA stream +inbound_rdata_count = metrics.register_counter( + "inbound_rdata_count", + labels=["stream_name"], +) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 95fa95fce3..e7ac01da01 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path, logger.debug("Responding with %r", file_path) if os.path.isfile(file_path): - request.setHeader(b"Content-Type", media_type.encode("UTF-8")) - if upload_name: - if is_ascii(upload_name): - request.setHeader( - b"Content-Disposition", - b"inline; filename=%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) - else: - request.setHeader( - b"Content-Disposition", - b"inline; filename*=utf-8''%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) - - # cache for at least a day. - # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. don't bother setting Expires as all our - # clients are smart enough to be happy with Cache-Control - request.setHeader( - b"Cache-Control", b"public,max-age=86400,s-maxage=86400" - ) if file_size is None: stat = os.stat(file_path) file_size = stat.st_size - request.setHeader( - b"Content-Length", b"%d" % (file_size,) - ) + add_file_headers(request, media_type, file_size, upload_name) with open(file_path, "rb") as f: yield logcontext.make_deferred_yieldable( @@ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path, finish_request(request) else: respond_404(request) + + +def add_file_headers(request, media_type, file_size, upload_name): + """Adds the correct response headers in preparation for responding with the + media. + + Args: + request (twisted.web.http.Request) + media_type (str): The media/content type. + file_size (int): Size in bytes of the media, if known. + upload_name (str): The name of the requested file, if any. + """ + request.setHeader(b"Content-Type", media_type.encode("UTF-8")) + if upload_name: + if is_ascii(upload_name): + request.setHeader( + b"Content-Disposition", + b"inline; filename=%s" % ( + urllib.quote(upload_name.encode("utf-8")), + ), + ) + else: + request.setHeader( + b"Content-Disposition", + b"inline; filename*=utf-8''%s" % ( + urllib.quote(upload_name.encode("utf-8")), + ), + ) + + # cache for at least a day. + # XXX: we might want to turn this off for data we don't want to + # recommend caching as it's sensitive or private - or at least + # select private. don't bother setting Expires as all our + # clients are smart enough to be happy with Cache-Control + request.setHeader( + b"Cache-Control", b"public,max-age=86400,s-maxage=86400" + ) + + request.setHeader( + b"Content-Length", b"%d" % (file_size,) + ) + + +@defer.inlineCallbacks +def respond_with_responder(request, responder, media_type, file_size, upload_name=None): + """Responds to the request with given responder. If responder is None then + returns 404. + + Args: + request (twisted.web.http.Request) + responder (Responder|None) + media_type (str): The media/content type. + file_size (int|None): Size in bytes of the media. If not known it should be None + upload_name (str|None): The name of the requested file, if any. + """ + if not responder: + respond_404(request) + return + + add_file_headers(request, media_type, file_size, upload_name) + with responder: + yield responder.write_to_consumer(request) + finish_request(request) + + +class Responder(object): + """Represents a response that can be streamed to the requester. + + Responder is a context manager which *must* be used, so that any resources + held can be cleaned up. + """ + def write_to_consumer(self, consumer): + """Stream response into consumer + + Args: + consumer (IConsumer) + + Returns: + Deferred: Resolves once the response has finished being written + """ + pass + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class FileInfo(object): + """Details about a requested/uploaded file. + + Attributes: + server_name (str): The server name where the media originated from, + or None if local. + file_id (str): The local ID of the file. For local files this is the + same as the media_id + url_cache (bool): If the file is for the url preview cache + thumbnail (bool): Whether the file is a thumbnail or not. + thumbnail_width (int) + thumbnail_height (int) + thumbnail_method (str) + thumbnail_type (str): Content type of thumbnail, e.g. image/png + """ + def __init__(self, server_name, file_id, url_cache=False, + thumbnail=False, thumbnail_width=None, thumbnail_height=None, + thumbnail_method=None, thumbnail_type=None): + self.server_name = server_name + self.file_id = file_id + self.url_cache = url_cache + self.thumbnail = thumbnail + self.thumbnail_width = thumbnail_width + self.thumbnail_height = thumbnail_height + self.thumbnail_method = thumbnail_method + self.thumbnail_type = thumbnail_type diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 6879249c8a..fe7e17596f 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -14,7 +14,7 @@ # limitations under the License. import synapse.http.servlet -from ._base import parse_media_id, respond_with_file, respond_404 +from ._base import parse_media_id, respond_404 from twisted.web.resource import Resource from synapse.http.server import request_handler, set_cors_headers @@ -32,12 +32,12 @@ class DownloadResource(Resource): def __init__(self, hs, media_repo): Resource.__init__(self) - self.filepaths = media_repo.filepaths self.media_repo = media_repo self.server_name = hs.hostname - self.store = hs.get_datastore() - self.version_string = hs.version_string + + # Both of these are expected by @request_handler() self.clock = hs.get_clock() + self.version_string = hs.version_string def render_GET(self, request): self._async_render_GET(request) @@ -57,59 +57,16 @@ class DownloadResource(Resource): ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: - yield self._respond_local_file(request, media_id, name) + yield self.media_repo.get_local_media(request, media_id, name) else: - yield self._respond_remote_file( - request, server_name, media_id, name - ) - - @defer.inlineCallbacks - def _respond_local_file(self, request, media_id, name): - media_info = yield self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: - respond_404(request) - return - - media_type = media_info["media_type"] - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] - if media_info["url_cache"]: - # TODO: Check the file still exists, if it doesn't we can redownload - # it from the url `media_info["url_cache"]` - file_path = self.filepaths.url_cache_filepath(media_id) - else: - file_path = self.filepaths.local_media_filepath(media_id) - - yield respond_with_file( - request, media_type, file_path, media_length, - upload_name=upload_name, - ) - - @defer.inlineCallbacks - def _respond_remote_file(self, request, server_name, media_id, name): - # don't forward requests for remote media if allow_remote is false - allow_remote = synapse.http.servlet.parse_boolean( - request, "allow_remote", default=True) - if not allow_remote: - logger.info( - "Rejecting request for remote media %s/%s due to allow_remote", - server_name, media_id, - ) - respond_404(request) - return - - media_info = yield self.media_repo.get_remote_media(server_name, media_id) - - media_type = media_info["media_type"] - media_length = media_info["media_length"] - filesystem_id = media_info["filesystem_id"] - upload_name = name if name else media_info["upload_name"] - - file_path = self.filepaths.remote_media_filepath( - server_name, filesystem_id - ) - - yield respond_with_file( - request, media_type, file_path, media_length, - upload_name=upload_name, - ) + allow_remote = synapse.http.servlet.parse_boolean( + request, "allow_remote", default=True) + if not allow_remote: + logger.info( + "Rejecting request for remote media %s/%s due to allow_remote", + server_name, media_id, + ) + respond_404(request) + return + + yield self.media_repo.get_remote_media(request, server_name, media_id, name) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index eed9056a2f..b2c76440b7 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +19,7 @@ import twisted.internet.error import twisted.web.http from twisted.web.resource import Resource +from ._base import respond_404, FileInfo, respond_with_responder from .upload_resource import UploadResource from .download_resource import DownloadResource from .thumbnail_resource import ThumbnailResource @@ -25,6 +27,10 @@ from .identicon_resource import IdenticonResource from .preview_url_resource import PreviewUrlResource from .filepath import MediaFilePaths from .thumbnailer import Thumbnailer +from .storage_provider import ( + StorageProviderWrapper, FileStorageProviderBackend, +) +from .media_storage import MediaStorage from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.util.stringutils import random_string @@ -33,7 +39,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \ from synapse.util.async import Linearizer from synapse.util.stringutils import is_ascii -from synapse.util.logcontext import make_deferred_yieldable, preserve_fn +from synapse.util.logcontext import make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination import os @@ -47,7 +53,7 @@ import urlparse logger = logging.getLogger(__name__) -UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000 +UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 class MediaRepository(object): @@ -63,96 +69,64 @@ class MediaRepository(object): self.primary_base_path = hs.config.media_store_path self.filepaths = MediaFilePaths(self.primary_base_path) - self.backup_base_path = hs.config.backup_media_store_path - - self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store - self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") self.recently_accessed_remotes = set() + self.recently_accessed_locals = set() - self.clock.looping_call( - self._update_recently_accessed_remotes, - UPDATE_RECENTLY_ACCESSED_REMOTES_TS - ) + # List of StorageProviders where we should search for media and + # potentially upload to. + storage_providers = [] - @defer.inlineCallbacks - def _update_recently_accessed_remotes(self): - media = self.recently_accessed_remotes - self.recently_accessed_remotes = set() + # TODO: Move this into config and allow other storage providers to be + # defined. + if hs.config.backup_media_store_path: + backend = FileStorageProviderBackend( + self.primary_base_path, hs.config.backup_media_store_path, + ) + provider = StorageProviderWrapper( + backend, + store=True, + store_synchronous=hs.config.synchronous_backup_media_store, + store_remote=True, + ) + storage_providers.append(provider) - yield self.store.update_cached_last_access_time( - media, self.clock.time_msec() + self.media_storage = MediaStorage( + self.primary_base_path, self.filepaths, storage_providers, ) - @staticmethod - def _makedirs(filepath): - dirname = os.path.dirname(filepath) - if not os.path.exists(dirname): - os.makedirs(dirname) - - @staticmethod - def _write_file_synchronously(source, fname): - """Write `source` to the path `fname` synchronously. Should be called - from a thread. - - Args: - source: A file like object to be written - fname (str): Path to write to - """ - MediaRepository._makedirs(fname) - source.seek(0) # Ensure we read from the start of the file - with open(fname, "wb") as f: - shutil.copyfileobj(source, f) + self.clock.looping_call( + self._update_recently_accessed, + UPDATE_RECENTLY_ACCESSED_TS, + ) @defer.inlineCallbacks - def write_to_file_and_backup(self, source, path): - """Write `source` to the on disk media store, and also the backup store - if configured. - - Args: - source: A file like object that should be written - path (str): Relative path to write file to - - Returns: - Deferred[str]: the file path written to in the primary media store - """ - fname = os.path.join(self.primary_base_path, path) - - # Write to the main repository - yield make_deferred_yieldable(threads.deferToThread( - self._write_file_synchronously, source, fname, - )) + def _update_recently_accessed(self): + remote_media = self.recently_accessed_remotes + self.recently_accessed_remotes = set() - # Write to backup repository - yield self.copy_to_backup(path) + local_media = self.recently_accessed_locals + self.recently_accessed_locals = set() - defer.returnValue(fname) + yield self.store.update_cached_last_access_time( + local_media, remote_media, self.clock.time_msec() + ) - @defer.inlineCallbacks - def copy_to_backup(self, path): - """Copy a file from the primary to backup media store, if configured. + def mark_recently_accessed(self, server_name, media_id): + """Mark the given media as recently accessed. Args: - path(str): Relative path to write file to + server_name (str|None): Origin server of media, or None if local + media_id (str): The media ID of the content """ - if self.backup_base_path: - primary_fname = os.path.join(self.primary_base_path, path) - backup_fname = os.path.join(self.backup_base_path, path) - - # We can either wait for successful writing to the backup repository - # or write in the background and immediately return - if self.synchronous_backup_media_store: - yield make_deferred_yieldable(threads.deferToThread( - shutil.copyfile, primary_fname, backup_fname, - )) - else: - preserve_fn(threads.deferToThread)( - shutil.copyfile, primary_fname, backup_fname, - ) + if server_name: + self.recently_accessed_remotes.add((server_name, media_id)) + else: + self.recently_accessed_locals.add(media_id) @defer.inlineCallbacks def create_content(self, media_type, upload_name, content, content_length, @@ -171,10 +145,13 @@ class MediaRepository(object): """ media_id = random_string(24) - fname = yield self.write_to_file_and_backup( - content, self.filepaths.local_media_filepath_rel(media_id) + file_info = FileInfo( + server_name=None, + file_id=media_id, ) + fname = yield self.media_storage.store_file(content, file_info) + logger.info("Stored local media in file %r", fname) yield self.store.store_local_media( @@ -185,134 +162,263 @@ class MediaRepository(object): media_length=content_length, user_id=auth_user, ) - media_info = { - "media_type": media_type, - "media_length": content_length, - } - yield self._generate_thumbnails(None, media_id, media_info) + yield self._generate_thumbnails( + None, media_id, media_id, media_type, + ) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) @defer.inlineCallbacks - def get_remote_media(self, server_name, media_id): + def get_local_media(self, request, media_id, name): + """Responds to reqests for local media, if exists, or returns 404. + + Args: + request(twisted.web.http.Request) + media_id (str): The media ID of the content. (This is the same as + the file_id for local content.) + name (str|None): Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Deferred: Resolves once a response has successfully been written + to request + """ + media_info = yield self.store.get_local_media(media_id) + if not media_info or media_info["quarantined_by"]: + respond_404(request) + return + + self.mark_recently_accessed(None, media_id) + + media_type = media_info["media_type"] + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + url_cache = media_info["url_cache"] + + file_info = FileInfo( + None, media_id, + url_cache=url_cache, + ) + + responder = yield self.media_storage.fetch_media(file_info) + yield respond_with_responder( + request, responder, media_type, media_length, upload_name, + ) + + @defer.inlineCallbacks + def get_remote_media(self, request, server_name, media_id, name): + """Respond to requests for remote media. + + Args: + request(twisted.web.http.Request) + server_name (str): Remote server_name where the media originated. + media_id (str): The media ID of the content (as defined by the + remote server). + name (str|None): Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Deferred: Resolves once a response has successfully been written + to request + """ + self.mark_recently_accessed(server_name, media_id) + + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently + key = (server_name, media_id) + with (yield self.remote_media_linearizer.queue(key)): + responder, media_info = yield self._get_remote_media_impl( + server_name, media_id, + ) + + # We deliberately stream the file outside the lock + if responder: + media_type = media_info["media_type"] + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + yield respond_with_responder( + request, responder, media_type, media_length, upload_name, + ) + else: + respond_404(request) + + @defer.inlineCallbacks + def get_remote_media_info(self, server_name, media_id): + """Gets the media info associated with the remote file, downloading + if necessary. + + Args: + server_name (str): Remote server_name where the media originated. + media_id (str): The media ID of the content (as defined by the + remote server). + + Returns: + Deferred[dict]: The media_info of the file + """ + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently key = (server_name, media_id) with (yield self.remote_media_linearizer.queue(key)): - media_info = yield self._get_remote_media_impl(server_name, media_id) + responder, media_info = yield self._get_remote_media_impl( + server_name, media_id, + ) + + # Ensure we actually use the responder so that it releases resources + if responder: + with responder: + pass + defer.returnValue(media_info) @defer.inlineCallbacks def _get_remote_media_impl(self, server_name, media_id): + """Looks for media in local cache, if not there then attempt to + download from remote server. + + Args: + server_name (str): Remote server_name where the media originated. + media_id (str): The media ID of the content (as defined by the + remote server). + + Returns: + Deferred[(Responder, media_info)] + """ media_info = yield self.store.get_cached_remote_media( server_name, media_id ) - if not media_info: - media_info = yield self._download_remote_file( - server_name, media_id - ) - elif media_info["quarantined_by"]: - raise NotFoundError() + + # file_id is the ID we use to track the file locally. If we've already + # seen the file then reuse the existing ID, otherwise genereate a new + # one. + if media_info: + file_id = media_info["filesystem_id"] else: - self.recently_accessed_remotes.add((server_name, media_id)) - yield self.store.update_cached_last_access_time( - [(server_name, media_id)], self.clock.time_msec() - ) - defer.returnValue(media_info) + file_id = random_string(24) + + file_info = FileInfo(server_name, file_id) + + # If we have an entry in the DB, try and look for it + if media_info: + if media_info["quarantined_by"]: + logger.info("Media is quarantined") + raise NotFoundError() + + responder = yield self.media_storage.fetch_media(file_info) + if responder: + defer.returnValue((responder, media_info)) + + # Failed to find the file anywhere, lets download it. + + media_info = yield self._download_remote_file( + server_name, media_id, file_id + ) + + responder = yield self.media_storage.fetch_media(file_info) + defer.returnValue((responder, media_info)) @defer.inlineCallbacks - def _download_remote_file(self, server_name, media_id): - file_id = random_string(24) + def _download_remote_file(self, server_name, media_id, file_id): + """Attempt to download the remote file from the given server name, + using the given file_id as the local id. + + Args: + server_name (str): Originating server + media_id (str): The media ID of the content (as defined by the + remote server). This is different than the file_id, which is + locally generated. + file_id (str): Local file ID + + Returns: + Deferred[MediaInfo] + """ - fpath = self.filepaths.remote_media_filepath_rel( - server_name, file_id + file_info = FileInfo( + server_name=server_name, + file_id=file_id, ) - fname = os.path.join(self.primary_base_path, fpath) - self._makedirs(fname) - try: - with open(fname, "wb") as f: - request_path = "/".join(( - "/_matrix/media/v1/download", server_name, media_id, - )) + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + request_path = "/".join(( + "/_matrix/media/v1/download", server_name, media_id, + )) + try: + length, headers = yield self.client.get_file( + server_name, request_path, output_stream=f, + max_size=self.max_upload_size, args={ + # tell the remote server to 404 if it doesn't + # recognise the server_name, to make sure we don't + # end up with a routing loop. + "allow_remote": "false", + } + ) + except twisted.internet.error.DNSLookupError as e: + logger.warn("HTTP error fetching remote media %s/%s: %r", + server_name, media_id, e) + raise NotFoundError() + + except HttpResponseException as e: + logger.warn("HTTP error fetching remote media %s/%s: %s", + server_name, media_id, e.response) + if e.code == twisted.web.http.NOT_FOUND: + raise SynapseError.from_http_response_exception(e) + raise SynapseError(502, "Failed to fetch remote media") + + except SynapseError: + logger.exception("Failed to fetch remote media %s/%s", + server_name, media_id) + raise + except NotRetryingDestination: + logger.warn("Not retrying destination %r", server_name) + raise SynapseError(502, "Failed to fetch remote media") + except Exception: + logger.exception("Failed to fetch remote media %s/%s", + server_name, media_id) + raise SynapseError(502, "Failed to fetch remote media") + + yield finish() + + media_type = headers["Content-Type"][0] + + time_now_ms = self.clock.time_msec() + + content_disposition = headers.get("Content-Disposition", None) + if content_disposition: + _, params = cgi.parse_header(content_disposition[0],) + upload_name = None + + # First check if there is a valid UTF-8 filename + upload_name_utf8 = params.get("filename*", None) + if upload_name_utf8: + if upload_name_utf8.lower().startswith("utf-8''"): + upload_name = upload_name_utf8[7:] + + # If there isn't check for an ascii name. + if not upload_name: + upload_name_ascii = params.get("filename", None) + if upload_name_ascii and is_ascii(upload_name_ascii): + upload_name = upload_name_ascii + + if upload_name: + upload_name = urlparse.unquote(upload_name) try: - length, headers = yield self.client.get_file( - server_name, request_path, output_stream=f, - max_size=self.max_upload_size, args={ - # tell the remote server to 404 if it doesn't - # recognise the server_name, to make sure we don't - # end up with a routing loop. - "allow_remote": "false", - } - ) - except twisted.internet.error.DNSLookupError as e: - logger.warn("HTTP error fetching remote media %s/%s: %r", - server_name, media_id, e) - raise NotFoundError() - - except HttpResponseException as e: - logger.warn("HTTP error fetching remote media %s/%s: %s", - server_name, media_id, e.response) - if e.code == twisted.web.http.NOT_FOUND: - raise SynapseError.from_http_response_exception(e) - raise SynapseError(502, "Failed to fetch remote media") - - except SynapseError: - logger.exception("Failed to fetch remote media %s/%s", - server_name, media_id) - raise - except NotRetryingDestination: - logger.warn("Not retrying destination %r", server_name) - raise SynapseError(502, "Failed to fetch remote media") - except Exception: - logger.exception("Failed to fetch remote media %s/%s", - server_name, media_id) - raise SynapseError(502, "Failed to fetch remote media") - - yield self.copy_to_backup(fpath) - - media_type = headers["Content-Type"][0] - time_now_ms = self.clock.time_msec() - - content_disposition = headers.get("Content-Disposition", None) - if content_disposition: - _, params = cgi.parse_header(content_disposition[0],) - upload_name = None - - # First check if there is a valid UTF-8 filename - upload_name_utf8 = params.get("filename*", None) - if upload_name_utf8: - if upload_name_utf8.lower().startswith("utf-8''"): - upload_name = upload_name_utf8[7:] - - # If there isn't check for an ascii name. - if not upload_name: - upload_name_ascii = params.get("filename", None) - if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii - - if upload_name: - upload_name = urlparse.unquote(upload_name) - try: - upload_name = upload_name.decode("utf-8") - except UnicodeDecodeError: - upload_name = None - else: - upload_name = None - - logger.info("Stored remote media in file %r", fname) - - yield self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - ) - except Exception: - os.remove(fname) - raise + upload_name = upload_name.decode("utf-8") + except UnicodeDecodeError: + upload_name = None + else: + upload_name = None + + logger.info("Stored remote media in file %r", fname) + + yield self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) media_info = { "media_type": media_type, @@ -323,7 +429,7 @@ class MediaRepository(object): } yield self._generate_thumbnails( - server_name, media_id, media_info + server_name, media_id, file_id, media_type, ) defer.returnValue(media_info) @@ -368,11 +474,18 @@ class MediaRepository(object): if t_byte_source: try: - output_path = yield self.write_to_file_and_backup( - t_byte_source, - self.filepaths.local_media_thumbnail_rel( - media_id, t_width, t_height, t_type, t_method - ) + file_info = FileInfo( + server_name=None, + file_id=media_id, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + ) + + output_path = yield self.media_storage.store_file( + t_byte_source, file_info, ) finally: t_byte_source.close() @@ -400,11 +513,18 @@ class MediaRepository(object): if t_byte_source: try: - output_path = yield self.write_to_file_and_backup( - t_byte_source, - self.filepaths.remote_media_thumbnail_rel( - server_name, file_id, t_width, t_height, t_type, t_method - ) + file_info = FileInfo( + server_name=server_name, + file_id=media_id, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + ) + + output_path = yield self.media_storage.store_file( + t_byte_source, file_info, ) finally: t_byte_source.close() @@ -421,21 +541,22 @@ class MediaRepository(object): defer.returnValue(output_path) @defer.inlineCallbacks - def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False): + def _generate_thumbnails(self, server_name, media_id, file_id, media_type, + url_cache=False): """Generate and store thumbnails for an image. Args: - server_name(str|None): The server name if remote media, else None if local - media_id(str) - media_info(dict) - url_cache(bool): If we are thumbnailing images downloaded for the URL cache, + server_name (str|None): The server name if remote media, else None if local + media_id (str): The media ID of the content. (This is the same as + the file_id for local content) + file_id (str): Local file ID + media_type (str): The content type of the file + url_cache (bool): If we are thumbnailing images downloaded for the URL cache, used exclusively by the url previewer Returns: Deferred[dict]: Dict with "width" and "height" keys of original image """ - media_type = media_info["media_type"] - file_id = media_info.get("filesystem_id") requirements = self._get_thumbnail_requirements(media_type) if not requirements: return @@ -472,20 +593,6 @@ class MediaRepository(object): # Now we generate the thumbnails for each dimension, store it for (t_width, t_height, t_type), t_method in thumbnails.iteritems(): - # Work out the correct file name for thumbnail - if server_name: - file_path = self.filepaths.remote_media_thumbnail_rel( - server_name, file_id, t_width, t_height, t_type, t_method - ) - elif url_cache: - file_path = self.filepaths.url_cache_thumbnail_rel( - media_id, t_width, t_height, t_type, t_method - ) - else: - file_path = self.filepaths.local_media_thumbnail_rel( - media_id, t_width, t_height, t_type, t_method - ) - # Generate the thumbnail if t_method == "crop": t_byte_source = yield make_deferred_yieldable(threads.deferToThread( @@ -505,9 +612,19 @@ class MediaRepository(object): continue try: - # Write to disk - output_path = yield self.write_to_file_and_backup( - t_byte_source, file_path, + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + url_cache=url_cache, + ) + + output_path = yield self.media_storage.store_file( + t_byte_source, file_info, ) finally: t_byte_source.close() @@ -620,7 +737,11 @@ class MediaRepositoryResource(Resource): self.putChild("upload", UploadResource(hs, media_repo)) self.putChild("download", DownloadResource(hs, media_repo)) - self.putChild("thumbnail", ThumbnailResource(hs, media_repo)) + self.putChild("thumbnail", ThumbnailResource( + hs, media_repo, media_repo.media_storage, + )) self.putChild("identicon", IdenticonResource()) if hs.config.url_preview_enabled: - self.putChild("preview_url", PreviewUrlResource(hs, media_repo)) + self.putChild("preview_url", PreviewUrlResource( + hs, media_repo, media_repo.media_storage, + )) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py new file mode 100644 index 0000000000..001e84578e --- /dev/null +++ b/synapse/rest/media/v1/media_storage.py @@ -0,0 +1,228 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vecotr Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer, threads +from twisted.protocols.basic import FileSender + +from ._base import Responder + +from synapse.util.logcontext import make_deferred_yieldable + +import contextlib +import os +import logging +import shutil +import sys + +logger = logging.getLogger(__name__) + + +class MediaStorage(object): + """Responsible for storing/fetching files from local sources. + + Args: + local_media_directory (str): Base path where we store media on disk + filepaths (MediaFilePaths) + storage_providers ([StorageProvider]): List of StorageProvider that are + used to fetch and store files. + """ + + def __init__(self, local_media_directory, filepaths, storage_providers): + self.local_media_directory = local_media_directory + self.filepaths = filepaths + self.storage_providers = storage_providers + + @defer.inlineCallbacks + def store_file(self, source, file_info): + """Write `source` to the on disk media store, and also any other + configured storage providers + + Args: + source: A file like object that should be written + file_info (FileInfo): Info about the file to store + + Returns: + Deferred[str]: the file path written to in the primary media store + """ + path = self._file_info_to_path(file_info) + fname = os.path.join(self.local_media_directory, path) + + dirname = os.path.dirname(fname) + if not os.path.exists(dirname): + os.makedirs(dirname) + + # Write to the main repository + yield make_deferred_yieldable(threads.deferToThread( + _write_file_synchronously, source, fname, + )) + + defer.returnValue(fname) + + @contextlib.contextmanager + def store_into_file(self, file_info): + """Context manager used to get a file like object to write into, as + described by file_info. + + Actually yields a 3-tuple (file, fname, finish_cb), where file is a file + like object that can be written to, fname is the absolute path of file + on disk, and finish_cb is a function that returns a Deferred. + + fname can be used to read the contents from after upload, e.g. to + generate thumbnails. + + finish_cb must be called and waited on after the file has been + successfully been written to. Should not be called if there was an + error. + + Args: + file_info (FileInfo): Info about the file to store + + Example: + + with media_storage.store_into_file(info) as (f, fname, finish_cb): + # .. write into f ... + yield finish_cb() + """ + + path = self._file_info_to_path(file_info) + fname = os.path.join(self.local_media_directory, path) + + dirname = os.path.dirname(fname) + if not os.path.exists(dirname): + os.makedirs(dirname) + + finished_called = [False] + + @defer.inlineCallbacks + def finish(): + for provider in self.storage_providers: + yield provider.store_file(path, file_info) + + finished_called[0] = True + + try: + with open(fname, "wb") as f: + yield f, fname, finish + except Exception: + t, v, tb = sys.exc_info() + try: + os.remove(fname) + except Exception: + pass + raise t, v, tb + + if not finished_called: + raise Exception("Finished callback not called") + + @defer.inlineCallbacks + def fetch_media(self, file_info): + """Attempts to fetch media described by file_info from the local cache + and configured storage providers. + + Args: + file_info (FileInfo) + + Returns: + Deferred[Responder|None]: Returns a Responder if the file was found, + otherwise None. + """ + + path = self._file_info_to_path(file_info) + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + defer.returnValue(FileResponder(open(local_path, "rb"))) + + for provider in self.storage_providers: + res = yield provider.fetch(path, file_info) + if res: + defer.returnValue(res) + + defer.returnValue(None) + + def _file_info_to_path(self, file_info): + """Converts file_info into a relative path. + + The path is suitable for storing files under a directory, e.g. used to + store files on local FS under the base media repository directory. + + Args: + file_info (FileInfo) + + Returns: + str + """ + if file_info.url_cache: + return self.filepaths.url_cache_filepath_rel(file_info.file_id) + + if file_info.server_name: + if file_info.thumbnail: + return self.filepaths.remote_media_thumbnail_rel( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + method=file_info.thumbnail_method + ) + return self.filepaths.remote_media_filepath_rel( + file_info.server_name, file_info.file_id, + ) + + if file_info.thumbnail: + return self.filepaths.local_media_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + method=file_info.thumbnail_method + ) + return self.filepaths.local_media_filepath_rel( + file_info.file_id, + ) + + +def _write_file_synchronously(source, fname): + """Write `source` to the path `fname` synchronously. Should be called + from a thread. + + Args: + source: A file like object to be written + fname (str): Path to write to + """ + dirname = os.path.dirname(fname) + if not os.path.exists(dirname): + os.makedirs(dirname) + + source.seek(0) # Ensure we read from the start of the file + with open(fname, "wb") as f: + shutil.copyfileobj(source, f) + + +class FileResponder(Responder): + """Wraps an open file that can be sent to a request. + + Args: + open_file (file): A file like object to be streamed ot the client, + is closed when finished streaming. + """ + def __init__(self, open_file): + self.open_file = open_file + + @defer.inlineCallbacks + def write_to_consumer(self, consumer): + yield FileSender().beginFileTransfer(self.open_file, consumer) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.open_file.close() diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 40d2e664eb..981f01e417 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -17,6 +17,8 @@ from twisted.web.server import NOT_DONE_YET from twisted.internet import defer from twisted.web.resource import Resource +from ._base import FileInfo + from synapse.api.errors import ( SynapseError, Codes, ) @@ -49,7 +51,7 @@ logger = logging.getLogger(__name__) class PreviewUrlResource(Resource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs, media_repo, media_storage): Resource.__init__(self) self.auth = hs.get_auth() @@ -62,6 +64,7 @@ class PreviewUrlResource(Resource): self.client = SpiderHttpClient(hs) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path + self.media_storage = media_storage self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist @@ -182,8 +185,10 @@ class PreviewUrlResource(Resource): logger.debug("got media_info of '%s'" % media_info) if _is_media(media_info['media_type']): + file_id = media_info['filesystem_id'] dims = yield self.media_repo._generate_thumbnails( - None, media_info['filesystem_id'], media_info, url_cache=True, + None, file_id, file_id, media_info["media_type"], + url_cache=True, ) og = { @@ -228,8 +233,10 @@ class PreviewUrlResource(Resource): if _is_media(image_info['media_type']): # TODO: make sure we don't choke on white-on-transparent images + file_id = image_info['filesystem_id'] dims = yield self.media_repo._generate_thumbnails( - None, image_info['filesystem_id'], image_info, url_cache=True, + None, file_id, file_id, image_info["media_type"], + url_cache=True, ) if dims: og["og:image:width"] = dims['width'] @@ -273,19 +280,21 @@ class PreviewUrlResource(Resource): file_id = datetime.date.today().isoformat() + '_' + random_string(16) - fpath = self.filepaths.url_cache_filepath_rel(file_id) - fname = os.path.join(self.primary_base_path, fpath) - self.media_repo._makedirs(fname) + file_info = FileInfo( + server_name=None, + file_id=file_id, + url_cache=True, + ) try: - with open(fname, "wb") as f: + with self.media_storage.store_into_file(file_info) as (f, fname, finish): logger.debug("Trying to get url '%s'" % url) length, headers, uri, code = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) # FIXME: pass through 404s and other error messages nicely - yield self.media_repo.copy_to_backup(fpath) + yield finish() media_type = headers["Content-Type"][0] time_now_ms = self.clock.time_msec() @@ -327,7 +336,6 @@ class PreviewUrlResource(Resource): ) except Exception as e: - os.remove(fname) raise SynapseError( 500, ("Failed to download content: %s" % e), Codes.UNKNOWN diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py new file mode 100644 index 0000000000..2ad602e101 --- /dev/null +++ b/synapse/rest/media/v1/storage_provider.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer, threads + +from .media_storage import FileResponder + +from synapse.util.logcontext import preserve_fn + +import logging +import os +import shutil + + +logger = logging.getLogger(__name__) + + +class StorageProvider(object): + """A storage provider is a service that can store uploaded media and + retrieve them. + """ + def store_file(self, path, file_info): + """Store the file described by file_info. The actual contents can be + retrieved by reading the file in file_info.upload_path. + + Args: + path (str): Relative path of file in local cache + file_info (FileInfo) + + Returns: + Deferred + """ + pass + + def fetch(self, path, file_info): + """Attempt to fetch the file described by file_info and stream it + into writer. + + Args: + path (str): Relative path of file in local cache + file_info (FileInfo) + + Returns: + Deferred(Responder): Returns a Responder if the provider has the file, + otherwise returns None. + """ + pass + + +class StorageProviderWrapper(StorageProvider): + """Wraps a storage provider and provides various config options + + Args: + backend (StorageProvider) + store (bool): Whether to store new files or not. + store_synchronous (bool): Whether to wait for file to be successfully + uploaded, or todo the upload in the backgroud. + store_remote (bool): Whether remote media should be uploaded + """ + def __init__(self, backend, store, store_synchronous, store_remote): + self.backend = backend + self.store = store + self.store_synchronous = store_synchronous + self.store_remote = store_remote + + def store_file(self, path, file_info): + if not self.store: + return defer.succeed(None) + + if file_info.server_name and not self.store_remote: + return defer.succeed(None) + + if self.store_synchronous: + return self.backend.store_file(path, file_info) + else: + # TODO: Handle errors. + preserve_fn(self.backend.store_file)(path, file_info) + return defer.succeed(None) + + def fetch(self, path, file_info): + return self.backend.fetch(path, file_info) + + +class FileStorageProviderBackend(StorageProvider): + """A storage provider that stores files in a directory on a filesystem. + + Args: + cache_directory (str): Base path of the local media repository + base_directory (str): Base path to store new files + """ + + def __init__(self, cache_directory, base_directory): + self.cache_directory = cache_directory + self.base_directory = base_directory + + def store_file(self, path, file_info): + """See StorageProvider.store_file""" + + primary_fname = os.path.join(self.cache_directory, path) + backup_fname = os.path.join(self.base_directory, path) + + dirname = os.path.dirname(backup_fname) + if not os.path.exists(dirname): + os.makedirs(dirname) + + return threads.deferToThread( + shutil.copyfile, primary_fname, backup_fname, + ) + + def fetch(self, path, file_info): + """See StorageProvider.fetch""" + + backup_fname = os.path.join(self.base_directory, path) + if os.path.isfile(backup_fname): + return FileResponder(open(backup_fname, "rb")) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 53e48aba22..c09f2dec41 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -14,7 +14,10 @@ # limitations under the License. -from ._base import parse_media_id, respond_404, respond_with_file +from ._base import ( + parse_media_id, respond_404, respond_with_file, FileInfo, + respond_with_responder, +) from twisted.web.resource import Resource from synapse.http.servlet import parse_string, parse_integer from synapse.http.server import request_handler, set_cors_headers @@ -30,12 +33,12 @@ logger = logging.getLogger(__name__) class ThumbnailResource(Resource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs, media_repo, media_storage): Resource.__init__(self) self.store = hs.get_datastore() - self.filepaths = media_repo.filepaths self.media_repo = media_repo + self.media_storage = media_storage self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname self.version_string = hs.version_string @@ -64,6 +67,7 @@ class ThumbnailResource(Resource): yield self._respond_local_thumbnail( request, media_id, width, height, method, m_type ) + self.media_repo.mark_recently_accessed(server_name, media_id) else: if self.dynamic_thumbnails: yield self._select_or_generate_remote_thumbnail( @@ -75,13 +79,18 @@ class ThumbnailResource(Resource): request, server_name, media_id, width, height, method, m_type ) + self.media_repo.mark_recently_accessed(None, media_id) @defer.inlineCallbacks def _respond_local_thumbnail(self, request, media_id, width, height, method, m_type): media_info = yield self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: + if not media_info: + respond_404(request) + return + if media_info["quarantined_by"]: + logger.info("Media is quarantined") respond_404(request) return @@ -91,24 +100,24 @@ class ThumbnailResource(Resource): thumbnail_info = self._select_thumbnail( width, height, method, m_type, thumbnail_infos ) - t_width = thumbnail_info["thumbnail_width"] - t_height = thumbnail_info["thumbnail_height"] - t_type = thumbnail_info["thumbnail_type"] - t_method = thumbnail_info["thumbnail_method"] - - if media_info["url_cache"]: - # TODO: Check the file still exists, if it doesn't we can redownload - # it from the url `media_info["url_cache"]` - file_path = self.filepaths.url_cache_thumbnail( - media_id, t_width, t_height, t_type, t_method, - ) - else: - file_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method, - ) - yield respond_with_file(request, t_type, file_path) + file_info = FileInfo( + server_name=None, file_id=media_id, + url_cache=media_info["url_cache"], + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + ) + + t_type = file_info.thumbnail_type + t_length = thumbnail_info["thumbnail_length"] + + responder = yield self.media_storage.fetch_media(file_info) + yield respond_with_responder(request, responder, t_type, t_length) else: + logger.info("Couldn't find any generated thumbnails") respond_404(request) @defer.inlineCallbacks @@ -117,7 +126,11 @@ class ThumbnailResource(Resource): desired_type): media_info = yield self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: + if not media_info: + respond_404(request) + return + if media_info["quarantined_by"]: + logger.info("Media is quarantined") respond_404(request) return @@ -129,22 +142,25 @@ class ThumbnailResource(Resource): t_type = info["thumbnail_type"] == desired_type if t_w and t_h and t_method and t_type: - if media_info["url_cache"]: - # TODO: Check the file still exists, if it doesn't we can redownload - # it from the url `media_info["url_cache"]` - file_path = self.filepaths.url_cache_thumbnail( - media_id, desired_width, desired_height, desired_type, - desired_method, - ) - else: - file_path = self.filepaths.local_media_thumbnail( - media_id, desired_width, desired_height, desired_type, - desired_method, - ) - yield respond_with_file(request, desired_type, file_path) - return - - logger.debug("We don't have a local thumbnail of that size. Generating") + file_info = FileInfo( + server_name=None, file_id=media_id, + url_cache=media_info["url_cache"], + thumbnail=True, + thumbnail_width=info["thumbnail_width"], + thumbnail_height=info["thumbnail_height"], + thumbnail_type=info["thumbnail_type"], + thumbnail_method=info["thumbnail_method"], + ) + + t_type = file_info.thumbnail_type + t_length = info["thumbnail_length"] + + responder = yield self.media_storage.fetch_media(file_info) + if responder: + yield respond_with_responder(request, responder, t_type, t_length) + return + + logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. file_path = yield self.media_repo.generate_local_exact_thumbnail( @@ -154,13 +170,14 @@ class ThumbnailResource(Resource): if file_path: yield respond_with_file(request, desired_type, file_path) else: + logger.warn("Failed to generate thumbnail") respond_404(request) @defer.inlineCallbacks def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, desired_width, desired_height, desired_method, desired_type): - media_info = yield self.media_repo.get_remote_media(server_name, media_id) + media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, @@ -175,14 +192,24 @@ class ThumbnailResource(Resource): t_type = info["thumbnail_type"] == desired_type if t_w and t_h and t_method and t_type: - file_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, desired_width, desired_height, - desired_type, desired_method, + file_info = FileInfo( + server_name=server_name, file_id=media_info["filesystem_id"], + thumbnail=True, + thumbnail_width=info["thumbnail_width"], + thumbnail_height=info["thumbnail_height"], + thumbnail_type=info["thumbnail_type"], + thumbnail_method=info["thumbnail_method"], ) - yield respond_with_file(request, desired_type, file_path) - return - logger.debug("We don't have a local thumbnail of that size. Generating") + t_type = file_info.thumbnail_type + t_length = info["thumbnail_length"] + + responder = yield self.media_storage.fetch_media(file_info) + if responder: + yield respond_with_responder(request, responder, t_type, t_length) + return + + logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. file_path = yield self.media_repo.generate_remote_exact_thumbnail( @@ -193,6 +220,7 @@ class ThumbnailResource(Resource): if file_path: yield respond_with_file(request, desired_type, file_path) else: + logger.warn("Failed to generate thumbnail") respond_404(request) @defer.inlineCallbacks @@ -201,7 +229,7 @@ class ThumbnailResource(Resource): # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. - yield self.media_repo.get_remote_media(server_name, media_id) + media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, @@ -211,18 +239,22 @@ class ThumbnailResource(Resource): thumbnail_info = self._select_thumbnail( width, height, method, m_type, thumbnail_infos ) - t_width = thumbnail_info["thumbnail_width"] - t_height = thumbnail_info["thumbnail_height"] - t_type = thumbnail_info["thumbnail_type"] - t_method = thumbnail_info["thumbnail_method"] - file_id = thumbnail_info["filesystem_id"] + file_info = FileInfo( + server_name=server_name, file_id=media_info["filesystem_id"], + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + ) + + t_type = file_info.thumbnail_type t_length = thumbnail_info["thumbnail_length"] - file_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method, - ) - yield respond_with_file(request, t_type, file_path, t_length) + responder = yield self.media_storage.fetch_media(file_info) + yield respond_with_responder(request, responder, t_type, t_length) else: + logger.info("Failed to find any generated thumbnails") respond_404(request) def _select_thumbnail(self, desired_width, desired_height, desired_method, diff --git a/synapse/state.py b/synapse/state.py index 9e624b4937..1f9abf9d3d 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -341,7 +341,7 @@ class StateHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events( + new_state = yield resolve_events_with_factory( state_groups_ids.values(), state_map_factory=lambda ev_ids: self.store.get_events( ev_ids, get_prev_content=False, check_redacted=False, @@ -404,7 +404,7 @@ class StateHandler(object): } with Measure(self.clock, "state._resolve_events"): - new_state = resolve_events(state_set_ids, state_map) + new_state = resolve_events_with_state_map(state_set_ids, state_map) new_state = { key: state_map[ev_id] for key, ev_id in new_state.items() @@ -420,19 +420,17 @@ def _ordered_events(events): return sorted(events, key=key_func) -def resolve_events(state_sets, state_map_factory): +def resolve_events_with_state_map(state_sets, state_map): """ Args: state_sets(list): List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - state_map_factory(dict|callable): If callable, then will be called - with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. Otherwise, should be - a dict from event_id to event of all events in state_sets. + state_map(dict): a dict from event_id to event, for all events in + state_sets. Returns - dict[(str, str), synapse.events.FrozenEvent] is a map from - (type, state_key) to event. + dict[(str, str), synapse.events.FrozenEvent]: + a map from (type, state_key) to event. """ if len(state_sets) == 1: return state_sets[0] @@ -441,13 +439,6 @@ def resolve_events(state_sets, state_map_factory): state_sets, ) - if callable(state_map_factory): - return _resolve_with_state_fac( - unconflicted_state, conflicted_state, state_map_factory - ) - - state_map = state_map_factory - auth_events = _create_auth_events_from_maps( unconflicted_state, conflicted_state, state_map ) @@ -491,8 +482,26 @@ def _seperate(state_sets): @defer.inlineCallbacks -def _resolve_with_state_fac(unconflicted_state, conflicted_state, - state_map_factory): +def resolve_events_with_factory(state_sets, state_map_factory): + """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + state_map_factory(func): will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. + + Returns + Deferred[dict[(str, str), synapse.events.FrozenEvent]]: + a map from (type, state_key) to event. + """ + if len(state_sets) == 1: + defer.returnValue(state_sets[0]) + + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + needed_events = set( event_id for event_ids in conflicted_state.itervalues() diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b971f0cb18..68125006eb 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -291,33 +291,33 @@ class SQLBaseStore(object): @defer.inlineCallbacks def runInteraction(self, desc, func, *args, **kwargs): - """Wraps the .runInteraction() method on the underlying db_pool.""" - current_context = LoggingContext.current_context() + """Starts a transaction on the database and runs a given function - start_time = time.time() * 1000 + Arguments: + desc (str): description of the transaction, for logging and metrics + func (func): callback function, which will be called with a + database transaction (twisted.enterprise.adbapi.Transaction) as + its first argument, followed by `args` and `kwargs`. + + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + current_context = LoggingContext.current_context() after_callbacks = [] final_callbacks = [] def inner_func(conn, *args, **kwargs): - with LoggingContext("runInteraction") as context: - sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) - - if self.database_engine.is_connection_closed(conn): - logger.debug("Reconnecting closed database connection") - conn.reconnect() - - current_context.copy_to(context) - return self._new_transaction( - conn, desc, after_callbacks, final_callbacks, current_context, - func, *args, **kwargs - ) + return self._new_transaction( + conn, desc, after_callbacks, final_callbacks, current_context, + func, *args, **kwargs + ) try: - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) + result = yield self.runWithConnection(inner_func, *args, **kwargs) for after_callback, after_args, after_kwargs in after_callbacks: after_callback(*after_args, **after_kwargs) @@ -329,14 +329,27 @@ class SQLBaseStore(object): @defer.inlineCallbacks def runWithConnection(self, func, *args, **kwargs): - """Wraps the .runInteraction() method on the underlying db_pool.""" + """Wraps the .runWithConnection() method on the underlying db_pool. + + Arguments: + func (func): callback function, which will be called with a + database connection (twisted.enterprise.adbapi.Connection) as + its first argument, followed by `args` and `kwargs`. + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ current_context = LoggingContext.current_context() start_time = time.time() * 1000 def inner_func(conn, *args, **kwargs): with LoggingContext("runWithConnection") as context: - sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + sched_duration_ms = time.time() * 1000 - start_time + sql_scheduling_timer.inc_by(sched_duration_ms) + current_context.add_database_scheduled(sched_duration_ms) if self.database_engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ba0da83642..7a9cd3ec90 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -27,7 +27,7 @@ from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError -from synapse.state import resolve_events +from synapse.state import resolve_events_with_factory from synapse.util.caches.descriptors import cached from synapse.types import get_domain_from_id @@ -146,6 +146,9 @@ class _EventPeristenceQueue(object): try: queue = self._get_drainining_queue(room_id) for item in queue: + # handle_queue_loop runs in the sentinel logcontext, so + # there is no need to preserve_fn when running the + # callbacks on the deferred. try: ret = yield per_item_callback(item) item.deferred.callback(ret) @@ -157,7 +160,11 @@ class _EventPeristenceQueue(object): self._event_persist_queues[room_id] = queue self._currently_persisting_rooms.discard(room_id) - preserve_fn(handle_queue_loop)() + # set handle_queue_loop off on the background. We don't want to + # attribute work done in it to the current request, so we drop the + # logcontext altogether. + with PreserveLoggingContext(): + handle_queue_loop() def _get_drainining_queue(self, room_id): queue = self._event_persist_queues.setdefault(room_id, deque()) @@ -556,7 +563,7 @@ class EventsStore(SQLBaseStore): to_return.update(evs) defer.returnValue(to_return) - current_state = yield resolve_events( + current_state = yield resolve_events_with_factory( state_sets, state_map_factory=get_events, ) diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 6ebc372498..e6cdbb0545 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -173,7 +173,14 @@ class MediaRepositoryStore(BackgroundUpdateStore): desc="store_cached_remote_media", ) - def update_cached_last_access_time(self, origin_id_tuples, time_ts): + def update_cached_last_access_time(self, local_media, remote_media, time_ms): + """Updates the last access time of the given media + + Args: + local_media (iterable[str]): Set of media_ids + remote_media (iterable[(str, str)]): Set of (server_name, media_id) + time_ms: Current time in milliseconds + """ def update_cache_txn(txn): sql = ( "UPDATE remote_media_cache SET last_access_ts = ?" @@ -181,8 +188,18 @@ class MediaRepositoryStore(BackgroundUpdateStore): ) txn.executemany(sql, ( - (time_ts, media_origin, media_id) - for media_origin, media_id in origin_id_tuples + (time_ms, media_origin, media_id) + for media_origin, media_id in remote_media + )) + + sql = ( + "UPDATE local_media_repository SET last_access_ts = ?" + " WHERE media_id = ?" + ) + + txn.executemany(sql, ( + (time_ms, media_id) + for media_id in local_media )) return self.runInteraction("update_cached_last_access_time", update_cache_txn) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index d1691bbac2..c845a0cec5 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 46 +SCHEMA_VERSION = 47 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql new file mode 100644 index 0000000000..f505fb22b5 --- /dev/null +++ b/synapse/storage/schema/delta/47/last_access_media.sql @@ -0,0 +1,16 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT; diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index c9bff408ef..f150ef0103 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -641,8 +641,13 @@ class UserDirectoryStore(SQLBaseStore): """ if self.hs.config.user_directory_search_all_users: - join_clause = "" - where_clause = "?<>''" # naughty hack to keep the same number of binds + # dummy to keep the number of binds & aliases the same + join_clause = """ + LEFT JOIN ( + SELECT NULL as user_id WHERE NULL = ? + ) AS s USING (user_id)" + """ + where_clause = "" else: join_clause = """ LEFT JOIN users_in_public_rooms AS p USING (user_id) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 48c9f6802d..94fa7cac98 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -52,13 +52,17 @@ except Exception: class LoggingContext(object): """Additional context for log formatting. Contexts are scoped within a "with" block. + Args: name (str): Name for the context for debugging. """ __slots__ = [ - "previous_context", "name", "usage_start", "usage_end", "main_thread", - "__dict__", "tag", "alive", + "previous_context", "name", "ru_stime", "ru_utime", + "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms", + "usage_start", "usage_end", + "main_thread", "alive", + "request", "tag", ] thread_local = threading.local() @@ -83,6 +87,9 @@ class LoggingContext(object): def add_database_transaction(self, duration_ms): pass + def add_database_scheduled(self, sched_ms): + pass + def __nonzero__(self): return False @@ -94,9 +101,17 @@ class LoggingContext(object): self.ru_stime = 0. self.ru_utime = 0. self.db_txn_count = 0 - self.db_txn_duration = 0. + + # ms spent waiting for db txns, excluding scheduling time + self.db_txn_duration_ms = 0 + + # ms spent waiting for db txns to be scheduled + self.db_sched_duration_ms = 0 + self.usage_start = None + self.usage_end = None self.main_thread = threading.current_thread() + self.request = None self.tag = "" self.alive = True @@ -105,7 +120,11 @@ class LoggingContext(object): @classmethod def current_context(cls): - """Get the current logging context from thread local storage""" + """Get the current logging context from thread local storage + + Returns: + LoggingContext: the current logging context + """ return getattr(cls.thread_local, "current_context", cls.sentinel) @classmethod @@ -155,11 +174,13 @@ class LoggingContext(object): self.alive = False def copy_to(self, record): - """Copy fields from this context to the record""" - for key, value in self.__dict__.items(): - setattr(record, key, value) + """Copy logging fields from this context to a log record or + another LoggingContext + """ - record.ru_utime, record.ru_stime = self.get_resource_usage() + # 'request' is the only field we currently use in the logger, so that's + # all we need to copy + record.request = self.request def start(self): if threading.current_thread() is not self.main_thread: @@ -194,7 +215,16 @@ class LoggingContext(object): def add_database_transaction(self, duration_ms): self.db_txn_count += 1 - self.db_txn_duration += duration_ms / 1000. + self.db_txn_duration_ms += duration_ms + + def add_database_scheduled(self, sched_ms): + """Record a use of the database pool + + Args: + sched_ms (int): number of milliseconds it took us to get a + connection + """ + self.db_sched_duration_ms += sched_ms class LoggingContextFilter(logging.Filter): diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4ea930d3e8..059bb7fedf 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -27,25 +27,62 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) -block_timer = metrics.register_distribution( - "block_timer", - labels=["block_name"] +# total number of times we have hit this block +response_count = metrics.register_counter( + "block_count", + labels=["block_name"], + alternative_names=( + # the following are all deprecated aliases for the same metric + metrics.name_prefix + x for x in ( + "_block_timer:count", + "_block_ru_utime:count", + "_block_ru_stime:count", + "_block_db_txn_count:count", + "_block_db_txn_duration:count", + ) + ) +) + +block_timer = metrics.register_counter( + "block_time_seconds", + labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_timer:total", + ), ) -block_ru_utime = metrics.register_distribution( - "block_ru_utime", labels=["block_name"] +block_ru_utime = metrics.register_counter( + "block_ru_utime_seconds", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_ru_utime:total", + ), ) -block_ru_stime = metrics.register_distribution( - "block_ru_stime", labels=["block_name"] +block_ru_stime = metrics.register_counter( + "block_ru_stime_seconds", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_ru_stime:total", + ), ) -block_db_txn_count = metrics.register_distribution( - "block_db_txn_count", labels=["block_name"] +block_db_txn_count = metrics.register_counter( + "block_db_txn_count", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_db_txn_count:total", + ), ) -block_db_txn_duration = metrics.register_distribution( - "block_db_txn_duration", labels=["block_name"] +# seconds spent waiting for db txns, excluding scheduling time, in this block +block_db_txn_duration = metrics.register_counter( + "block_db_txn_duration_seconds", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_db_txn_count:total", + ), +) + +# seconds spent waiting for a db connection, in this block +block_db_sched_duration = metrics.register_counter( + "block_db_sched_duration_seconds", labels=["block_name"], ) @@ -64,7 +101,9 @@ def measure_func(name): class Measure(object): __slots__ = [ "clock", "name", "start_context", "start", "new_context", "ru_utime", - "ru_stime", "db_txn_count", "db_txn_duration", "created_context" + "ru_stime", + "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms", + "created_context", ] def __init__(self, clock, name): @@ -84,7 +123,8 @@ class Measure(object): self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() self.db_txn_count = self.start_context.db_txn_count - self.db_txn_duration = self.start_context.db_txn_duration + self.db_txn_duration_ms = self.start_context.db_txn_duration_ms + self.db_sched_duration_ms = self.start_context.db_sched_duration_ms def __exit__(self, exc_type, exc_val, exc_tb): if isinstance(exc_type, Exception) or not self.start_context: @@ -114,7 +154,12 @@ class Measure(object): context.db_txn_count - self.db_txn_count, self.name ) block_db_txn_duration.inc_by( - context.db_txn_duration - self.db_txn_duration, self.name + (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000., + self.name + ) + block_db_sched_duration.inc_by( + (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000., + self.name ) if self.created_context: |