diff --git a/CHANGES.rst b/CHANGES.rst
index 24e4e7a384..a7ed49e105 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -2,7 +2,10 @@ Unreleased
==========
synctl no longer starts the main synapse when using ``-a`` option with workers.
-A new worker file should be added with ``worker_app: synapse.app.homeserver``
+A new worker file should be added with ``worker_app: synapse.app.homeserver``.
+
+This release also begins the process of renaming a number of the metrics
+reported to prometheus. See `docs/metrics-howto.rst <docs/metrics-howto.rst#block-and-response-metrics-renamed-for-0-27-0>`_.
Changes in synapse v0.26.0 (2018-01-05)
diff --git a/docs/metrics-howto.rst b/docs/metrics-howto.rst
index 143cd0f42f..8acc479bc3 100644
--- a/docs/metrics-howto.rst
+++ b/docs/metrics-howto.rst
@@ -16,7 +16,7 @@ How to monitor Synapse metrics using Prometheus
metrics_port: 9092
Also ensure that ``enable_metrics`` is set to ``True``.
-
+
Restart synapse.
3. Add a prometheus target for synapse.
@@ -28,11 +28,58 @@ How to monitor Synapse metrics using Prometheus
static_configs:
- targets: ["my.server.here:9092"]
- If your prometheus is older than 1.5.2, you will need to replace
+ If your prometheus is older than 1.5.2, you will need to replace
``static_configs`` in the above with ``target_groups``.
-
+
Restart prometheus.
+
+Block and response metrics renamed for 0.27.0
+---------------------------------------------
+
+Synapse 0.27.0 begins the process of rationalising the duplicate ``*:count``
+metrics reported by the resource tracking for code blocks and HTTP requests.
+
+At the same time, the corresponding ``*:total`` metrics are being renamed, as
+the ``:total`` suffix no longer makes sense in the absence of a corresponding
+``:count`` metric.
+
+To enable a graceful migration path, this release just adds new names for the
+metrics being renamed. A future release will remove the old ones.
+
+The following table shows the new metrics, and the old metrics which they are
+replacing.
+
+==================================================== ===================================================
+New name Old name
+==================================================== ===================================================
+synapse_util_metrics_block_count synapse_util_metrics_block_timer:count
+synapse_util_metrics_block_count synapse_util_metrics_block_ru_utime:count
+synapse_util_metrics_block_count synapse_util_metrics_block_ru_stime:count
+synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_count:count
+synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_duration:count
+
+synapse_util_metrics_block_time_seconds synapse_util_metrics_block_timer:total
+synapse_util_metrics_block_ru_utime_seconds synapse_util_metrics_block_ru_utime:total
+synapse_util_metrics_block_ru_stime_seconds synapse_util_metrics_block_ru_stime:total
+synapse_util_metrics_block_db_txn_count synapse_util_metrics_block_db_txn_count:total
+synapse_util_metrics_block_db_txn_duration_seconds synapse_util_metrics_block_db_txn_duration:total
+
+synapse_http_server_response_count synapse_http_server_requests
+synapse_http_server_response_count synapse_http_server_response_time:count
+synapse_http_server_response_count synapse_http_server_response_ru_utime:count
+synapse_http_server_response_count synapse_http_server_response_ru_stime:count
+synapse_http_server_response_count synapse_http_server_response_db_txn_count:count
+synapse_http_server_response_count synapse_http_server_response_db_txn_duration:count
+
+synapse_http_server_response_time_seconds synapse_http_server_response_time:total
+synapse_http_server_response_ru_utime_seconds synapse_http_server_response_ru_utime:total
+synapse_http_server_response_ru_stime_seconds synapse_http_server_response_ru_stime:total
+synapse_http_server_response_db_txn_count synapse_http_server_response_db_txn_count:total
+synapse_http_server_response_db_txn_duration_seconds synapse_http_server_response_db_txn_duration:total
+==================================================== ===================================================
+
+
Standard Metric Names
---------------------
@@ -42,7 +89,7 @@ have been changed to seconds, from miliseconds.
================================== =============================
New name Old name
----------------------------------- -----------------------------
+================================== =============================
process_cpu_user_seconds_total process_resource_utime / 1000
process_cpu_system_seconds_total process_resource_stime / 1000
process_open_fds (no 'type' label) process_fds
@@ -52,8 +99,8 @@ The python-specific counts of garbage collector performance have been renamed.
=========================== ======================
New name Old name
---------------------------- ----------------------
-python_gc_time reactor_gc_time
+=========================== ======================
+python_gc_time reactor_gc_time
python_gc_unreachable_total reactor_gc_unreachable
python_gc_counts reactor_gc_counts
=========================== ======================
@@ -62,7 +109,7 @@ The twisted-specific reactor metrics have been renamed.
==================================== =====================
New name Old name
------------------------------------- ---------------------
+==================================== =====================
python_twisted_reactor_pending_calls reactor_pending_calls
python_twisted_reactor_tick_time reactor_tick_time
==================================== =====================
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 4748f71c2f..29eb012ddb 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -96,7 +96,7 @@ class TlsConfig(Config):
# certificates returned by this server match one of the fingerprints.
#
# Synapse automatically adds the fingerprint of its own certificate
- # to the list. So if federation traffic is handle directly by synapse
+ # to the list. So if federation traffic is handled directly by synapse
# then no modification to the list is required.
#
# If synapse is run behind a load balancer that handles the TLS then it
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 3e7809b04f..9d39f46583 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus")
sent_transactions_counter = client_metrics.register_counter("sent_transactions")
+events_processed_counter = client_metrics.register_counter("events_processed")
+
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
@@ -205,6 +207,8 @@ class TransactionQueue(object):
self._send_pdu(event, destinations)
+ events_processed_counter.inc_by(len(events))
+
yield self.store.update_federation_out_pos(
"events", next_token
)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index feca3e4c10..3dd3fa2a27 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+import synapse
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
@@ -23,6 +24,10 @@ import logging
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+events_processed_counter = metrics.register_counter("events_processed")
+
def log_failure(failure):
logger.error(
@@ -103,6 +108,8 @@ class ApplicationServicesHandler(object):
service, event
)
+ events_processed_counter.inc_by(len(events))
+
yield self.store.set_appservice_last_pos(upper_bound)
finally:
self.is_processing = False
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index f7fad15c62..d996aa90bb 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -17,7 +17,8 @@ import logging
from twisted.internet import defer
-from synapse.types import get_domain_from_id
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id, UserID
from synapse.util.stringutils import random_string
@@ -33,7 +34,7 @@ class DeviceMessageHandler(object):
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler(
@@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
+                logger.warning("To-device message to non-local user %s",
+ user_id)
+ raise SynapseError(400, "Not a user here")
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
messages_by_device = {
device_id: {
"content": message_content,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 668a90e495..5af8abf66b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -20,7 +20,7 @@ from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException
-from synapse.types import get_domain_from_id
+from synapse.types import get_domain_from_id, UserID
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
@@ -32,7 +32,7 @@ class E2eKeysHandler(object):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.device_handler = hs.get_device_handler()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the
@@ -70,7 +70,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
remote_queries[user_id] = device_ids
@@ -170,7 +171,8 @@ class E2eKeysHandler(object):
result_dict = {}
for user_id, device_ids in query.items():
- if not self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
@@ -213,7 +215,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 6e8f4c9c5f..165c684d0d 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -42,36 +42,70 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
-incoming_requests_counter = metrics.register_counter(
- "requests",
+# total number of responses served, split by method/servlet/tag
+response_count = metrics.register_counter(
+ "response_count",
labels=["method", "servlet", "tag"],
+ alternative_names=(
+ # the following are all deprecated aliases for the same metric
+ metrics.name_prefix + x for x in (
+ "_requests",
+ "_response_time:count",
+ "_response_ru_utime:count",
+ "_response_ru_stime:count",
+ "_response_db_txn_count:count",
+ "_response_db_txn_duration:count",
+ )
+ )
)
+
outgoing_responses_counter = metrics.register_counter(
"responses",
labels=["method", "code"],
)
-response_timer = metrics.register_distribution(
- "response_time",
- labels=["method", "servlet", "tag"]
+response_timer = metrics.register_counter(
+ "response_time_seconds",
+ labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_time:total",
+ ),
)
-response_ru_utime = metrics.register_distribution(
- "response_ru_utime", labels=["method", "servlet", "tag"]
+response_ru_utime = metrics.register_counter(
+ "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_ru_utime:total",
+ ),
)
-response_ru_stime = metrics.register_distribution(
- "response_ru_stime", labels=["method", "servlet", "tag"]
+response_ru_stime = metrics.register_counter(
+ "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_ru_stime:total",
+ ),
)
-response_db_txn_count = metrics.register_distribution(
- "response_db_txn_count", labels=["method", "servlet", "tag"]
+response_db_txn_count = metrics.register_counter(
+ "response_db_txn_count", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_db_txn_count:total",
+ ),
)
-response_db_txn_duration = metrics.register_distribution(
- "response_db_txn_duration", labels=["method", "servlet", "tag"]
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+response_db_txn_duration = metrics.register_counter(
+ "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_db_txn_duration:total",
+ ),
)
+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = metrics.register_counter(
+ "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
+)
_next_request_id = 0
@@ -288,15 +322,6 @@ class JsonResource(HttpServer, resource.Resource):
def _send_response(self, request, code, response_json_object,
response_code_message=None):
- # could alternatively use request.notifyFinish() and flip a flag when
- # the Deferred fires, but since the flag is RIGHT THERE it seems like
- # a waste.
- if request._disconnected:
- logger.warn(
- "Not sending response to request %s, already disconnected.",
- request)
- return
-
outgoing_responses_counter.inc(request.method, str(code))
# TODO: Only enable CORS for the requests that need it.
@@ -330,7 +355,7 @@ class RequestMetrics(object):
)
return
- incoming_requests_counter.inc(request.method, self.name, tag)
+ response_count.inc(request.method, self.name, tag)
response_timer.inc_by(
clock.time_msec() - self.start, request.method,
@@ -349,7 +374,10 @@ class RequestMetrics(object):
context.db_txn_count, request.method, self.name, tag
)
response_db_txn_duration.inc_by(
- context.db_txn_duration, request.method, self.name, tag
+ context.db_txn_duration_ms / 1000., request.method, self.name, tag
+ )
+ response_db_sched_duration.inc_by(
+ context.db_sched_duration_ms / 1000., request.method, self.name, tag
)
@@ -372,6 +400,15 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
version_string="", canonical_json=True):
+ # could alternatively use request.notifyFinish() and flip a flag when
+ # the Deferred fires, but since the flag is RIGHT THERE it seems like
+ # a waste.
+ if request._disconnected:
+ logger.warn(
+ "Not sending response to request %s, already disconnected.",
+ request)
+ return
+
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index cd1492b1c3..e422c8dfae 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -66,14 +66,15 @@ class SynapseRequest(Request):
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
- db_txn_duration = context.db_txn_duration
+ db_txn_duration_ms = context.db_txn_duration_ms
+ db_sched_duration_ms = context.db_sched_duration_ms
except Exception:
ru_utime, ru_stime = (0, 0)
- db_txn_count, db_txn_duration = (0, 0)
+ db_txn_count, db_txn_duration_ms = (0, 0)
self.site.access_logger.info(
"%s - %s - {%s}"
- " Processed request: %dms (%dms, %dms) (%dms/%d)"
+ " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
@@ -81,7 +82,8 @@ class SynapseRequest(Request):
int(time.time() * 1000) - self.start_time,
int(ru_utime * 1000),
int(ru_stime * 1000),
- int(db_txn_duration * 1000),
+ db_sched_duration_ms,
+ db_txn_duration_ms,
int(db_txn_count),
self.sentLength,
self.code,
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index e87b2b80a7..f480aae614 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -17,16 +17,33 @@
from itertools import chain
-# TODO(paul): I can't believe Python doesn't have one of these
-def map_concat(func, items):
- # flatten a list-of-lists
- return list(chain.from_iterable(map(func, items)))
+def flatten(items):
+ """Flatten a list of lists
+
+ Args:
+ items: iterable[iterable[X]]
+
+ Returns:
+ list[X]: flattened list
+ """
+ return list(chain.from_iterable(items))
class BaseMetric(object):
+ """Base class for metrics which report a single value per label set
+ """
- def __init__(self, name, labels=[]):
- self.name = name
+ def __init__(self, name, labels=[], alternative_names=[]):
+ """
+ Args:
+ name (str): principal name for this metric
+ labels (list(str)): names of the labels which will be reported
+ for this metric
+ alternative_names (iterable(str)): list of alternative names for
+ this metric. This can be useful to provide a migration path
+ when renaming metrics.
+ """
+ self._names = [name] + list(alternative_names)
self.labels = labels # OK not to clone as we never write it
def dimension(self):
@@ -36,7 +53,7 @@ class BaseMetric(object):
return not len(self.labels)
def _render_labelvalue(self, value):
- # TODO: some kind of value escape
+ # TODO: escape backslashes, quotes and newlines
return '"%s"' % (value)
def _render_key(self, values):
@@ -47,19 +64,60 @@ class BaseMetric(object):
for k, v in zip(self.labels, values)])
)
+ def _render_for_labels(self, label_values, value):
+ """Render this metric for a single set of labels
+
+ Args:
+ label_values (list[str]): values for each of the labels
+            value: value of the metric with these labels
+
+ Returns:
+ iterable[str]: rendered metric
+ """
+ rendered_labels = self._render_key(label_values)
+ return (
+ "%s%s %.12g" % (name, rendered_labels, value)
+ for name in self._names
+ )
+
+ def render(self):
+ """Render this metric
+
+ Each metric is rendered as:
+
+ name{label1="val1",label2="val2"} value
+
+ https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
+
+ Returns:
+ iterable[str]: rendered metrics
+ """
+ raise NotImplementedError()
+
class CounterMetric(BaseMetric):
"""The simplest kind of metric; one that stores a monotonically-increasing
- integer that counts events."""
+ value that counts events or running totals.
+
+ Example use cases for Counters:
+ - Number of requests processed
+ - Number of items that were inserted into a queue
+ - Total amount of data that a system has processed
+ Counters can only go up (and be reset when the process restarts).
+ """
def __init__(self, *args, **kwargs):
super(CounterMetric, self).__init__(*args, **kwargs)
+        # dict[tuple, float]: value for each set of label values. the keys
+        # are tuples of the label values, in the same order as self.labels.
+        #
+        # (if the metric is a scalar, the (single) key is the empty tuple).
self.counts = {}
# Scalar metrics are never empty
if self.is_scalar():
- self.counts[()] = 0
+ self.counts[()] = 0.
def inc_by(self, incr, *values):
if len(values) != self.dimension():
@@ -77,11 +135,11 @@ class CounterMetric(BaseMetric):
def inc(self, *values):
self.inc_by(1, *values)
- def render_item(self, k):
- return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
-
def render(self):
- return map_concat(self.render_item, sorted(self.counts.keys()))
+ return flatten(
+ self._render_for_labels(k, self.counts[k])
+ for k in sorted(self.counts.keys())
+ )
class CallbackMetric(BaseMetric):
@@ -98,10 +156,12 @@ class CallbackMetric(BaseMetric):
value = self.callback()
if self.is_scalar():
- return ["%s %.12g" % (self.name, value)]
+ return list(self._render_for_labels([], value))
- return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
- for k in sorted(value.keys())]
+ return flatten(
+ self._render_for_labels(k, value[k])
+ for k in sorted(value.keys())
+ )
class DistributionMetric(object):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index d59503b905..0a9a290af4 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -517,25 +517,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_error("Wrong remote")
def on_RDATA(self, cmd):
+ stream_name = cmd.stream_name
+ inbound_rdata_count.inc(stream_name)
+
try:
- row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
+ row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
except Exception:
logger.exception(
"[%s] Failed to parse RDATA: %r %r",
- self.id(), cmd.stream_name, cmd.row
+ self.id(), stream_name, cmd.row
)
raise
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
- self.pending_batches.setdefault(cmd.stream_name, []).append(row)
+ self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
- rows = self.pending_batches.pop(cmd.stream_name, [])
+ rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
- self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
+ self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
self.handler.on_position(cmd.stream_name, cmd.token)
@@ -644,3 +647,9 @@ metrics.register_callback(
},
labels=["command", "name", "conn_id"],
)
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = metrics.register_counter(
+ "inbound_rdata_count",
+ labels=["stream_name"],
+)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 95fa95fce3..e7ac01da01 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path,
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
- request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
- if upload_name:
- if is_ascii(upload_name):
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename=%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
- else:
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename*=utf-8''%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
-
- # cache for at least a day.
- # XXX: we might want to turn this off for data we don't want to
- # recommend caching as it's sensitive or private - or at least
- # select private. don't bother setting Expires as all our
- # clients are smart enough to be happy with Cache-Control
- request.setHeader(
- b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
- )
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
- request.setHeader(
- b"Content-Length", b"%d" % (file_size,)
- )
+ add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
yield logcontext.make_deferred_yieldable(
@@ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path,
finish_request(request)
else:
respond_404(request)
+
+
+def add_file_headers(request, media_type, file_size, upload_name):
+ """Adds the correct response headers in preparation for responding with the
+ media.
+
+ Args:
+ request (twisted.web.http.Request)
+ media_type (str): The media/content type.
+ file_size (int): Size in bytes of the media, if known.
+ upload_name (str): The name of the requested file, if any.
+ """
+ request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+ if upload_name:
+ if is_ascii(upload_name):
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename=%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
+ else:
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename*=utf-8''%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
+
+ # cache for at least a day.
+ # XXX: we might want to turn this off for data we don't want to
+ # recommend caching as it's sensitive or private - or at least
+ # select private. don't bother setting Expires as all our
+ # clients are smart enough to be happy with Cache-Control
+ request.setHeader(
+ b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
+ )
+
+ request.setHeader(
+ b"Content-Length", b"%d" % (file_size,)
+ )
+
+
+@defer.inlineCallbacks
+def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+ """Responds to the request with given responder. If responder is None then
+ returns 404.
+
+ Args:
+ request (twisted.web.http.Request)
+ responder (Responder|None)
+ media_type (str): The media/content type.
+ file_size (int|None): Size in bytes of the media. If not known it should be None
+ upload_name (str|None): The name of the requested file, if any.
+ """
+ if not responder:
+ respond_404(request)
+ return
+
+ add_file_headers(request, media_type, file_size, upload_name)
+ with responder:
+ yield responder.write_to_consumer(request)
+ finish_request(request)
+
+
+class Responder(object):
+ """Represents a response that can be streamed to the requester.
+
+ Responder is a context manager which *must* be used, so that any resources
+ held can be cleaned up.
+ """
+ def write_to_consumer(self, consumer):
+ """Stream response into consumer
+
+ Args:
+ consumer (IConsumer)
+
+ Returns:
+ Deferred: Resolves once the response has finished being written
+ """
+ pass
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+
+class FileInfo(object):
+ """Details about a requested/uploaded file.
+
+ Attributes:
+ server_name (str): The server name where the media originated from,
+ or None if local.
+ file_id (str): The local ID of the file. For local files this is the
+ same as the media_id
+ url_cache (bool): If the file is for the url preview cache
+ thumbnail (bool): Whether the file is a thumbnail or not.
+ thumbnail_width (int)
+ thumbnail_height (int)
+ thumbnail_method (str)
+ thumbnail_type (str): Content type of thumbnail, e.g. image/png
+ """
+ def __init__(self, server_name, file_id, url_cache=False,
+ thumbnail=False, thumbnail_width=None, thumbnail_height=None,
+ thumbnail_method=None, thumbnail_type=None):
+ self.server_name = server_name
+ self.file_id = file_id
+ self.url_cache = url_cache
+ self.thumbnail = thumbnail
+ self.thumbnail_width = thumbnail_width
+ self.thumbnail_height = thumbnail_height
+ self.thumbnail_method = thumbnail_method
+ self.thumbnail_type = thumbnail_type
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 6879249c8a..fe7e17596f 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -14,7 +14,7 @@
# limitations under the License.
import synapse.http.servlet
-from ._base import parse_media_id, respond_with_file, respond_404
+from ._base import parse_media_id, respond_404
from twisted.web.resource import Resource
from synapse.http.server import request_handler, set_cors_headers
@@ -32,12 +32,12 @@ class DownloadResource(Resource):
def __init__(self, hs, media_repo):
Resource.__init__(self)
- self.filepaths = media_repo.filepaths
self.media_repo = media_repo
self.server_name = hs.hostname
- self.store = hs.get_datastore()
- self.version_string = hs.version_string
+
+ # Both of these are expected by @request_handler()
self.clock = hs.get_clock()
+ self.version_string = hs.version_string
def render_GET(self, request):
self._async_render_GET(request)
@@ -57,59 +57,16 @@ class DownloadResource(Resource):
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
- yield self._respond_local_file(request, media_id, name)
+ yield self.media_repo.get_local_media(request, media_id, name)
else:
- yield self._respond_remote_file(
- request, server_name, media_id, name
- )
-
- @defer.inlineCallbacks
- def _respond_local_file(self, request, media_id, name):
- media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
- respond_404(request)
- return
-
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- upload_name = name if name else media_info["upload_name"]
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_filepath(media_id)
- else:
- file_path = self.filepaths.local_media_filepath(media_id)
-
- yield respond_with_file(
- request, media_type, file_path, media_length,
- upload_name=upload_name,
- )
-
- @defer.inlineCallbacks
- def _respond_remote_file(self, request, server_name, media_id, name):
- # don't forward requests for remote media if allow_remote is false
- allow_remote = synapse.http.servlet.parse_boolean(
- request, "allow_remote", default=True)
- if not allow_remote:
- logger.info(
- "Rejecting request for remote media %s/%s due to allow_remote",
- server_name, media_id,
- )
- respond_404(request)
- return
-
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- filesystem_id = media_info["filesystem_id"]
- upload_name = name if name else media_info["upload_name"]
-
- file_path = self.filepaths.remote_media_filepath(
- server_name, filesystem_id
- )
-
- yield respond_with_file(
- request, media_type, file_path, media_length,
- upload_name=upload_name,
- )
+ allow_remote = synapse.http.servlet.parse_boolean(
+ request, "allow_remote", default=True)
+ if not allow_remote:
+ logger.info(
+ "Rejecting request for remote media %s/%s due to allow_remote",
+ server_name, media_id,
+ )
+ respond_404(request)
+ return
+
+ yield self.media_repo.get_remote_media(request, server_name, media_id, name)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index eed9056a2f..b2c76440b7 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@ import twisted.internet.error
import twisted.web.http
from twisted.web.resource import Resource
+from ._base import respond_404, FileInfo, respond_with_responder
from .upload_resource import UploadResource
from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource
@@ -25,6 +27,10 @@ from .identicon_resource import IdenticonResource
from .preview_url_resource import PreviewUrlResource
from .filepath import MediaFilePaths
from .thumbnailer import Thumbnailer
+from .storage_provider import (
+ StorageProviderWrapper, FileStorageProviderBackend,
+)
+from .media_storage import MediaStorage
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string
@@ -33,7 +39,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \
from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
import os
@@ -47,7 +53,7 @@ import urlparse
logger = logging.getLogger(__name__)
-UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object):
@@ -63,96 +69,64 @@ class MediaRepository(object):
self.primary_base_path = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path)
- self.backup_base_path = hs.config.backup_media_store_path
-
- self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
-
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
+ self.recently_accessed_locals = set()
- self.clock.looping_call(
- self._update_recently_accessed_remotes,
- UPDATE_RECENTLY_ACCESSED_REMOTES_TS
- )
+ # List of StorageProviders where we should search for media and
+ # potentially upload to.
+ storage_providers = []
- @defer.inlineCallbacks
- def _update_recently_accessed_remotes(self):
- media = self.recently_accessed_remotes
- self.recently_accessed_remotes = set()
+ # TODO: Move this into config and allow other storage providers to be
+ # defined.
+ if hs.config.backup_media_store_path:
+ backend = FileStorageProviderBackend(
+ self.primary_base_path, hs.config.backup_media_store_path,
+ )
+ provider = StorageProviderWrapper(
+ backend,
+ store=True,
+ store_synchronous=hs.config.synchronous_backup_media_store,
+ store_remote=True,
+ )
+ storage_providers.append(provider)
- yield self.store.update_cached_last_access_time(
- media, self.clock.time_msec()
+ self.media_storage = MediaStorage(
+ self.primary_base_path, self.filepaths, storage_providers,
)
- @staticmethod
- def _makedirs(filepath):
- dirname = os.path.dirname(filepath)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
-
- @staticmethod
- def _write_file_synchronously(source, fname):
- """Write `source` to the path `fname` synchronously. Should be called
- from a thread.
-
- Args:
- source: A file like object to be written
- fname (str): Path to write to
- """
- MediaRepository._makedirs(fname)
- source.seek(0) # Ensure we read from the start of the file
- with open(fname, "wb") as f:
- shutil.copyfileobj(source, f)
+ self.clock.looping_call(
+ self._update_recently_accessed,
+ UPDATE_RECENTLY_ACCESSED_TS,
+ )
@defer.inlineCallbacks
- def write_to_file_and_backup(self, source, path):
- """Write `source` to the on disk media store, and also the backup store
- if configured.
-
- Args:
- source: A file like object that should be written
- path (str): Relative path to write file to
-
- Returns:
- Deferred[str]: the file path written to in the primary media store
- """
- fname = os.path.join(self.primary_base_path, path)
-
- # Write to the main repository
- yield make_deferred_yieldable(threads.deferToThread(
- self._write_file_synchronously, source, fname,
- ))
+ def _update_recently_accessed(self):
+ remote_media = self.recently_accessed_remotes
+ self.recently_accessed_remotes = set()
- # Write to backup repository
- yield self.copy_to_backup(path)
+ local_media = self.recently_accessed_locals
+ self.recently_accessed_locals = set()
- defer.returnValue(fname)
+ yield self.store.update_cached_last_access_time(
+ local_media, remote_media, self.clock.time_msec()
+ )
- @defer.inlineCallbacks
- def copy_to_backup(self, path):
- """Copy a file from the primary to backup media store, if configured.
+ def mark_recently_accessed(self, server_name, media_id):
+ """Mark the given media as recently accessed.
Args:
- path(str): Relative path to write file to
+ server_name (str|None): Origin server of media, or None if local
+ media_id (str): The media ID of the content
"""
- if self.backup_base_path:
- primary_fname = os.path.join(self.primary_base_path, path)
- backup_fname = os.path.join(self.backup_base_path, path)
-
- # We can either wait for successful writing to the backup repository
- # or write in the background and immediately return
- if self.synchronous_backup_media_store:
- yield make_deferred_yieldable(threads.deferToThread(
- shutil.copyfile, primary_fname, backup_fname,
- ))
- else:
- preserve_fn(threads.deferToThread)(
- shutil.copyfile, primary_fname, backup_fname,
- )
+ if server_name:
+ self.recently_accessed_remotes.add((server_name, media_id))
+ else:
+ self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
@@ -171,10 +145,13 @@ class MediaRepository(object):
"""
media_id = random_string(24)
- fname = yield self.write_to_file_and_backup(
- content, self.filepaths.local_media_filepath_rel(media_id)
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
)
+ fname = yield self.media_storage.store_file(content, file_info)
+
logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media(
@@ -185,134 +162,263 @@ class MediaRepository(object):
media_length=content_length,
user_id=auth_user,
)
- media_info = {
- "media_type": media_type,
- "media_length": content_length,
- }
- yield self._generate_thumbnails(None, media_id, media_info)
+ yield self._generate_thumbnails(
+ None, media_id, media_id, media_type,
+ )
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@defer.inlineCallbacks
- def get_remote_media(self, server_name, media_id):
+ def get_local_media(self, request, media_id, name):
+ """Responds to reqests for local media, if exists, or returns 404.
+
+ Args:
+ request(twisted.web.http.Request)
+ media_id (str): The media ID of the content. (This is the same as
+ the file_id for local content.)
+ name (str|None): Optional name that, if specified, will be used as
+ the filename in the Content-Disposition header of the response.
+
+ Returns:
+ Deferred: Resolves once a response has successfully been written
+ to request
+ """
+ media_info = yield self.store.get_local_media(media_id)
+ if not media_info or media_info["quarantined_by"]:
+ respond_404(request)
+ return
+
+ self.mark_recently_accessed(None, media_id)
+
+ media_type = media_info["media_type"]
+ media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
+ url_cache = media_info["url_cache"]
+
+ file_info = FileInfo(
+ None, media_id,
+ url_cache=url_cache,
+ )
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(
+ request, responder, media_type, media_length, upload_name,
+ )
+
+ @defer.inlineCallbacks
+ def get_remote_media(self, request, server_name, media_id, name):
+ """Respond to requests for remote media.
+
+ Args:
+ request(twisted.web.http.Request)
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+ name (str|None): Optional name that, if specified, will be used as
+ the filename in the Content-Disposition header of the response.
+
+ Returns:
+ Deferred: Resolves once a response has successfully been written
+ to request
+ """
+ self.mark_recently_accessed(server_name, media_id)
+
+ # We linearize here to ensure that we don't try and download remote
+ # media multiple times concurrently
+ key = (server_name, media_id)
+ with (yield self.remote_media_linearizer.queue(key)):
+ responder, media_info = yield self._get_remote_media_impl(
+ server_name, media_id,
+ )
+
+ # We deliberately stream the file outside the lock
+ if responder:
+ media_type = media_info["media_type"]
+ media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
+ yield respond_with_responder(
+ request, responder, media_type, media_length, upload_name,
+ )
+ else:
+ respond_404(request)
+
+ @defer.inlineCallbacks
+ def get_remote_media_info(self, server_name, media_id):
+ """Gets the media info associated with the remote file, downloading
+ if necessary.
+
+ Args:
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+
+ Returns:
+ Deferred[dict]: The media_info of the file
+ """
+ # We linearize here to ensure that we don't try and download remote
+ # media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
- media_info = yield self._get_remote_media_impl(server_name, media_id)
+ responder, media_info = yield self._get_remote_media_impl(
+ server_name, media_id,
+ )
+
+ # Ensure we actually use the responder so that it releases resources
+ if responder:
+ with responder:
+ pass
+
defer.returnValue(media_info)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
+ """Looks for media in local cache, if not there then attempt to
+ download from remote server.
+
+ Args:
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+
+ Returns:
+ Deferred[(Responder, media_info)]
+ """
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
- if not media_info:
- media_info = yield self._download_remote_file(
- server_name, media_id
- )
- elif media_info["quarantined_by"]:
- raise NotFoundError()
+
+ # file_id is the ID we use to track the file locally. If we've already
+        # seen the file then reuse the existing ID, otherwise generate a new
+ # one.
+ if media_info:
+ file_id = media_info["filesystem_id"]
else:
- self.recently_accessed_remotes.add((server_name, media_id))
- yield self.store.update_cached_last_access_time(
- [(server_name, media_id)], self.clock.time_msec()
- )
- defer.returnValue(media_info)
+ file_id = random_string(24)
+
+ file_info = FileInfo(server_name, file_id)
+
+ # If we have an entry in the DB, try and look for it
+ if media_info:
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
+ raise NotFoundError()
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ defer.returnValue((responder, media_info))
+
+ # Failed to find the file anywhere, lets download it.
+
+ media_info = yield self._download_remote_file(
+ server_name, media_id, file_id
+ )
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ defer.returnValue((responder, media_info))
@defer.inlineCallbacks
- def _download_remote_file(self, server_name, media_id):
- file_id = random_string(24)
+ def _download_remote_file(self, server_name, media_id, file_id):
+ """Attempt to download the remote file from the given server name,
+ using the given file_id as the local id.
+
+ Args:
+ server_name (str): Originating server
+ media_id (str): The media ID of the content (as defined by the
+ remote server). This is different than the file_id, which is
+ locally generated.
+ file_id (str): Local file ID
+
+ Returns:
+ Deferred[MediaInfo]
+ """
- fpath = self.filepaths.remote_media_filepath_rel(
- server_name, file_id
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
)
- fname = os.path.join(self.primary_base_path, fpath)
- self._makedirs(fname)
- try:
- with open(fname, "wb") as f:
- request_path = "/".join((
- "/_matrix/media/v1/download", server_name, media_id,
- ))
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ request_path = "/".join((
+ "/_matrix/media/v1/download", server_name, media_id,
+ ))
+ try:
+ length, headers = yield self.client.get_file(
+ server_name, request_path, output_stream=f,
+ max_size=self.max_upload_size, args={
+ # tell the remote server to 404 if it doesn't
+ # recognise the server_name, to make sure we don't
+ # end up with a routing loop.
+ "allow_remote": "false",
+ }
+ )
+ except twisted.internet.error.DNSLookupError as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %r",
+ server_name, media_id, e)
+ raise NotFoundError()
+
+ except HttpResponseException as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %s",
+ server_name, media_id, e.response)
+ if e.code == twisted.web.http.NOT_FOUND:
+ raise SynapseError.from_http_response_exception(e)
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ except SynapseError:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise
+ except NotRetryingDestination:
+ logger.warn("Not retrying destination %r", server_name)
+ raise SynapseError(502, "Failed to fetch remote media")
+ except Exception:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ yield finish()
+
+ media_type = headers["Content-Type"][0]
+
+ time_now_ms = self.clock.time_msec()
+
+ content_disposition = headers.get("Content-Disposition", None)
+ if content_disposition:
+ _, params = cgi.parse_header(content_disposition[0],)
+ upload_name = None
+
+ # First check if there is a valid UTF-8 filename
+ upload_name_utf8 = params.get("filename*", None)
+ if upload_name_utf8:
+ if upload_name_utf8.lower().startswith("utf-8''"):
+ upload_name = upload_name_utf8[7:]
+
+ # If there isn't check for an ascii name.
+ if not upload_name:
+ upload_name_ascii = params.get("filename", None)
+ if upload_name_ascii and is_ascii(upload_name_ascii):
+ upload_name = upload_name_ascii
+
+ if upload_name:
+ upload_name = urlparse.unquote(upload_name)
try:
- length, headers = yield self.client.get_file(
- server_name, request_path, output_stream=f,
- max_size=self.max_upload_size, args={
- # tell the remote server to 404 if it doesn't
- # recognise the server_name, to make sure we don't
- # end up with a routing loop.
- "allow_remote": "false",
- }
- )
- except twisted.internet.error.DNSLookupError as e:
- logger.warn("HTTP error fetching remote media %s/%s: %r",
- server_name, media_id, e)
- raise NotFoundError()
-
- except HttpResponseException as e:
- logger.warn("HTTP error fetching remote media %s/%s: %s",
- server_name, media_id, e.response)
- if e.code == twisted.web.http.NOT_FOUND:
- raise SynapseError.from_http_response_exception(e)
- raise SynapseError(502, "Failed to fetch remote media")
-
- except SynapseError:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
- raise
- except NotRetryingDestination:
- logger.warn("Not retrying destination %r", server_name)
- raise SynapseError(502, "Failed to fetch remote media")
- except Exception:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
- raise SynapseError(502, "Failed to fetch remote media")
-
- yield self.copy_to_backup(fpath)
-
- media_type = headers["Content-Type"][0]
- time_now_ms = self.clock.time_msec()
-
- content_disposition = headers.get("Content-Disposition", None)
- if content_disposition:
- _, params = cgi.parse_header(content_disposition[0],)
- upload_name = None
-
- # First check if there is a valid UTF-8 filename
- upload_name_utf8 = params.get("filename*", None)
- if upload_name_utf8:
- if upload_name_utf8.lower().startswith("utf-8''"):
- upload_name = upload_name_utf8[7:]
-
- # If there isn't check for an ascii name.
- if not upload_name:
- upload_name_ascii = params.get("filename", None)
- if upload_name_ascii and is_ascii(upload_name_ascii):
- upload_name = upload_name_ascii
-
- if upload_name:
- upload_name = urlparse.unquote(upload_name)
- try:
- upload_name = upload_name.decode("utf-8")
- except UnicodeDecodeError:
- upload_name = None
- else:
- upload_name = None
-
- logger.info("Stored remote media in file %r", fname)
-
- yield self.store.store_cached_remote_media(
- origin=server_name,
- media_id=media_id,
- media_type=media_type,
- time_now_ms=self.clock.time_msec(),
- upload_name=upload_name,
- media_length=length,
- filesystem_id=file_id,
- )
- except Exception:
- os.remove(fname)
- raise
+ upload_name = upload_name.decode("utf-8")
+ except UnicodeDecodeError:
+ upload_name = None
+ else:
+ upload_name = None
+
+ logger.info("Stored remote media in file %r", fname)
+
+ yield self.store.store_cached_remote_media(
+ origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=upload_name,
+ media_length=length,
+ filesystem_id=file_id,
+ )
media_info = {
"media_type": media_type,
@@ -323,7 +429,7 @@ class MediaRepository(object):
}
yield self._generate_thumbnails(
- server_name, media_id, media_info
+ server_name, media_id, file_id, media_type,
)
defer.returnValue(media_info)
@@ -368,11 +474,18 @@ class MediaRepository(object):
if t_byte_source:
try:
- output_path = yield self.write_to_file_and_backup(
- t_byte_source,
- self.filepaths.local_media_thumbnail_rel(
- media_id, t_width, t_height, t_type, t_method
- )
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -400,11 +513,18 @@ class MediaRepository(object):
if t_byte_source:
try:
- output_path = yield self.write_to_file_and_backup(
- t_byte_source,
- self.filepaths.remote_media_thumbnail_rel(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=media_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -421,21 +541,22 @@ class MediaRepository(object):
defer.returnValue(output_path)
@defer.inlineCallbacks
- def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
+ def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
+ url_cache=False):
"""Generate and store thumbnails for an image.
Args:
- server_name(str|None): The server name if remote media, else None if local
- media_id(str)
- media_info(dict)
- url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
+ server_name (str|None): The server name if remote media, else None if local
+ media_id (str): The media ID of the content. (This is the same as
+ the file_id for local content)
+ file_id (str): Local file ID
+ media_type (str): The content type of the file
+ url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image
"""
- media_type = media_info["media_type"]
- file_id = media_info.get("filesystem_id")
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
@@ -472,20 +593,6 @@ class MediaRepository(object):
# Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
- # Work out the correct file name for thumbnail
- if server_name:
- file_path = self.filepaths.remote_media_thumbnail_rel(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- elif url_cache:
- file_path = self.filepaths.url_cache_thumbnail_rel(
- media_id, t_width, t_height, t_type, t_method
- )
- else:
- file_path = self.filepaths.local_media_thumbnail_rel(
- media_id, t_width, t_height, t_type, t_method
- )
-
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -505,9 +612,19 @@ class MediaRepository(object):
continue
try:
- # Write to disk
- output_path = yield self.write_to_file_and_backup(
- t_byte_source, file_path,
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ url_cache=url_cache,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -620,7 +737,11 @@ class MediaRepositoryResource(Resource):
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
- self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
+ self.putChild("thumbnail", ThumbnailResource(
+ hs, media_repo, media_repo.media_storage,
+ ))
self.putChild("identicon", IdenticonResource())
if hs.config.url_preview_enabled:
- self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
+ self.putChild("preview_url", PreviewUrlResource(
+ hs, media_repo, media_repo.media_storage,
+ ))
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
new file mode 100644
index 0000000000..001e84578e
--- /dev/null
+++ b/synapse/rest/media/v1/media_storage.py
@@ -0,0 +1,228 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+from twisted.protocols.basic import FileSender
+
+from ._base import Responder
+
+from synapse.util.logcontext import make_deferred_yieldable
+
+import contextlib
+import os
+import logging
+import shutil
+import sys
+
+logger = logging.getLogger(__name__)
+
+
+class MediaStorage(object):
+ """Responsible for storing/fetching files from local sources.
+
+ Args:
+ local_media_directory (str): Base path where we store media on disk
+ filepaths (MediaFilePaths)
+ storage_providers ([StorageProvider]): List of StorageProvider that are
+ used to fetch and store files.
+ """
+
+ def __init__(self, local_media_directory, filepaths, storage_providers):
+ self.local_media_directory = local_media_directory
+ self.filepaths = filepaths
+ self.storage_providers = storage_providers
+
+ @defer.inlineCallbacks
+ def store_file(self, source, file_info):
+ """Write `source` to the on disk media store, and also any other
+ configured storage providers
+
+ Args:
+ source: A file like object that should be written
+ file_info (FileInfo): Info about the file to store
+
+ Returns:
+ Deferred[str]: the file path written to in the primary media store
+ """
+ path = self._file_info_to_path(file_info)
+ fname = os.path.join(self.local_media_directory, path)
+
+ dirname = os.path.dirname(fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ # Write to the main repository
+ yield make_deferred_yieldable(threads.deferToThread(
+ _write_file_synchronously, source, fname,
+ ))
+
+ defer.returnValue(fname)
+
+ @contextlib.contextmanager
+ def store_into_file(self, file_info):
+ """Context manager used to get a file like object to write into, as
+ described by file_info.
+
+ Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
+ like object that can be written to, fname is the absolute path of file
+ on disk, and finish_cb is a function that returns a Deferred.
+
+ fname can be used to read the contents from after upload, e.g. to
+ generate thumbnails.
+
+    finish_cb must be called and waited on after the file has
+    successfully been written to. It should not be called if there was an
+    error.
+
+ Args:
+ file_info (FileInfo): Info about the file to store
+
+ Example:
+
+ with media_storage.store_into_file(info) as (f, fname, finish_cb):
+ # .. write into f ...
+ yield finish_cb()
+ """
+
+ path = self._file_info_to_path(file_info)
+ fname = os.path.join(self.local_media_directory, path)
+
+ dirname = os.path.dirname(fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ finished_called = [False]
+
+ @defer.inlineCallbacks
+ def finish():
+ for provider in self.storage_providers:
+ yield provider.store_file(path, file_info)
+
+ finished_called[0] = True
+
+ try:
+ with open(fname, "wb") as f:
+ yield f, fname, finish
+ except Exception:
+ t, v, tb = sys.exc_info()
+ try:
+ os.remove(fname)
+ except Exception:
+ pass
+ raise t, v, tb
+
+        if not finished_called[0]:
+ raise Exception("Finished callback not called")
+
+ @defer.inlineCallbacks
+ def fetch_media(self, file_info):
+ """Attempts to fetch media described by file_info from the local cache
+ and configured storage providers.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ Deferred[Responder|None]: Returns a Responder if the file was found,
+ otherwise None.
+ """
+
+ path = self._file_info_to_path(file_info)
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ defer.returnValue(FileResponder(open(local_path, "rb")))
+
+ for provider in self.storage_providers:
+ res = yield provider.fetch(path, file_info)
+ if res:
+ defer.returnValue(res)
+
+ defer.returnValue(None)
+
+ def _file_info_to_path(self, file_info):
+ """Converts file_info into a relative path.
+
+ The path is suitable for storing files under a directory, e.g. used to
+ store files on local FS under the base media repository directory.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ str
+ """
+ if file_info.url_cache:
+ return self.filepaths.url_cache_filepath_rel(file_info.file_id)
+
+ if file_info.server_name:
+ if file_info.thumbnail:
+ return self.filepaths.remote_media_thumbnail_rel(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method
+ )
+ return self.filepaths.remote_media_filepath_rel(
+ file_info.server_name, file_info.file_id,
+ )
+
+ if file_info.thumbnail:
+ return self.filepaths.local_media_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method
+ )
+ return self.filepaths.local_media_filepath_rel(
+ file_info.file_id,
+ )
+
+
+def _write_file_synchronously(source, fname):
+ """Write `source` to the path `fname` synchronously. Should be called
+ from a thread.
+
+ Args:
+ source: A file like object to be written
+ fname (str): Path to write to
+ """
+ dirname = os.path.dirname(fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ source.seek(0) # Ensure we read from the start of the file
+ with open(fname, "wb") as f:
+ shutil.copyfileobj(source, f)
+
+
+class FileResponder(Responder):
+ """Wraps an open file that can be sent to a request.
+
+ Args:
+        open_file (file): A file like object to be streamed to the client,
+ is closed when finished streaming.
+ """
+ def __init__(self, open_file):
+ self.open_file = open_file
+
+ @defer.inlineCallbacks
+ def write_to_consumer(self, consumer):
+ yield FileSender().beginFileTransfer(self.open_file, consumer)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.open_file.close()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 40d2e664eb..981f01e417 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -17,6 +17,8 @@ from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from twisted.web.resource import Resource
+from ._base import FileInfo
+
from synapse.api.errors import (
SynapseError, Codes,
)
@@ -49,7 +51,7 @@ logger = logging.getLogger(__name__)
class PreviewUrlResource(Resource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.auth = hs.get_auth()
@@ -62,6 +64,7 @@ class PreviewUrlResource(Resource):
self.client = SpiderHttpClient(hs)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
+ self.media_storage = media_storage
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
@@ -182,8 +185,10 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info)
if _is_media(media_info['media_type']):
+ file_id = media_info['filesystem_id']
dims = yield self.media_repo._generate_thumbnails(
- None, media_info['filesystem_id'], media_info, url_cache=True,
+ None, file_id, file_id, media_info["media_type"],
+ url_cache=True,
)
og = {
@@ -228,8 +233,10 @@ class PreviewUrlResource(Resource):
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
+ file_id = image_info['filesystem_id']
dims = yield self.media_repo._generate_thumbnails(
- None, image_info['filesystem_id'], image_info, url_cache=True,
+ None, file_id, file_id, image_info["media_type"],
+ url_cache=True,
)
if dims:
og["og:image:width"] = dims['width']
@@ -273,19 +280,21 @@ class PreviewUrlResource(Resource):
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
- fpath = self.filepaths.url_cache_filepath_rel(file_id)
- fname = os.path.join(self.primary_base_path, fpath)
- self.media_repo._makedirs(fname)
+ file_info = FileInfo(
+ server_name=None,
+ file_id=file_id,
+ url_cache=True,
+ )
try:
- with open(fname, "wb") as f:
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
logger.debug("Trying to get url '%s'" % url)
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size,
)
# FIXME: pass through 404s and other error messages nicely
- yield self.media_repo.copy_to_backup(fpath)
+ yield finish()
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
@@ -327,7 +336,6 @@ class PreviewUrlResource(Resource):
)
except Exception as e:
- os.remove(fname)
raise SynapseError(
500, ("Failed to download content: %s" % e),
Codes.UNKNOWN
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
new file mode 100644
index 0000000000..2ad602e101
--- /dev/null
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+
+from .media_storage import FileResponder
+
+from synapse.util.logcontext import preserve_fn
+
+import logging
+import os
+import shutil
+
+
+logger = logging.getLogger(__name__)
+
+
+class StorageProvider(object):
+ """A storage provider is a service that can store uploaded media and
+ retrieve them.
+ """
+ def store_file(self, path, file_info):
+ """Store the file described by file_info. The actual contents can be
+ retrieved by reading the file in file_info.upload_path.
+
+ Args:
+ path (str): Relative path of file in local cache
+ file_info (FileInfo)
+
+ Returns:
+ Deferred
+ """
+ pass
+
+ def fetch(self, path, file_info):
+        """Attempt to fetch the file described by file_info, returning a
+        Responder that can be used to stream its contents.
+
+ Args:
+ path (str): Relative path of file in local cache
+ file_info (FileInfo)
+
+ Returns:
+ Deferred(Responder): Returns a Responder if the provider has the file,
+ otherwise returns None.
+ """
+ pass
+
+
+class StorageProviderWrapper(StorageProvider):
+ """Wraps a storage provider and provides various config options
+
+ Args:
+ backend (StorageProvider)
+ store (bool): Whether to store new files or not.
+ store_synchronous (bool): Whether to wait for file to be successfully
+            uploaded, or to do the upload in the background.
+ store_remote (bool): Whether remote media should be uploaded
+ """
+ def __init__(self, backend, store, store_synchronous, store_remote):
+ self.backend = backend
+ self.store = store
+ self.store_synchronous = store_synchronous
+ self.store_remote = store_remote
+
+ def store_file(self, path, file_info):
+ if not self.store:
+ return defer.succeed(None)
+
+ if file_info.server_name and not self.store_remote:
+ return defer.succeed(None)
+
+ if self.store_synchronous:
+ return self.backend.store_file(path, file_info)
+ else:
+ # TODO: Handle errors.
+ preserve_fn(self.backend.store_file)(path, file_info)
+ return defer.succeed(None)
+
+ def fetch(self, path, file_info):
+ return self.backend.fetch(path, file_info)
+
+
+class FileStorageProviderBackend(StorageProvider):
+ """A storage provider that stores files in a directory on a filesystem.
+
+ Args:
+ cache_directory (str): Base path of the local media repository
+ base_directory (str): Base path to store new files
+ """
+
+ def __init__(self, cache_directory, base_directory):
+ self.cache_directory = cache_directory
+ self.base_directory = base_directory
+
+ def store_file(self, path, file_info):
+ """See StorageProvider.store_file"""
+
+ primary_fname = os.path.join(self.cache_directory, path)
+ backup_fname = os.path.join(self.base_directory, path)
+
+ dirname = os.path.dirname(backup_fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ return threads.deferToThread(
+ shutil.copyfile, primary_fname, backup_fname,
+ )
+
+ def fetch(self, path, file_info):
+ """See StorageProvider.fetch"""
+
+ backup_fname = os.path.join(self.base_directory, path)
+ if os.path.isfile(backup_fname):
+ return FileResponder(open(backup_fname, "rb"))
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 53e48aba22..c09f2dec41 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,7 +14,10 @@
# limitations under the License.
-from ._base import parse_media_id, respond_404, respond_with_file
+from ._base import (
+ parse_media_id, respond_404, respond_with_file, FileInfo,
+ respond_with_responder,
+)
from twisted.web.resource import Resource
from synapse.http.servlet import parse_string, parse_integer
from synapse.http.server import request_handler, set_cors_headers
@@ -30,12 +33,12 @@ logger = logging.getLogger(__name__)
class ThumbnailResource(Resource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.store = hs.get_datastore()
- self.filepaths = media_repo.filepaths
self.media_repo = media_repo
+ self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
self.version_string = hs.version_string
@@ -64,6 +67,7 @@ class ThumbnailResource(Resource):
yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(server_name, media_id)
else:
if self.dynamic_thumbnails:
yield self._select_or_generate_remote_thumbnail(
@@ -75,13 +79,18 @@ class ThumbnailResource(Resource):
request, server_name, media_id,
width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(None, media_id)
@defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height,
method, m_type):
media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
+ if not media_info:
+ respond_404(request)
+ return
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
respond_404(request)
return
@@ -91,24 +100,24 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
-
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_thumbnail(
- media_id, t_width, t_height, t_type, t_method,
- )
- else:
- file_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path)
+ file_info = FileInfo(
+ server_name=None, file_id=media_id,
+ url_cache=media_info["url_cache"],
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ )
+
+ t_type = file_info.thumbnail_type
+ t_length = thumbnail_info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(request, responder, t_type, t_length)
else:
+ logger.info("Couldn't find any generated thumbnails")
respond_404(request)
@defer.inlineCallbacks
@@ -117,7 +126,11 @@ class ThumbnailResource(Resource):
desired_type):
media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
+ if not media_info:
+ respond_404(request)
+ return
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
respond_404(request)
return
@@ -129,22 +142,25 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_thumbnail(
- media_id, desired_width, desired_height, desired_type,
- desired_method,
- )
- else:
- file_path = self.filepaths.local_media_thumbnail(
- media_id, desired_width, desired_height, desired_type,
- desired_method,
- )
- yield respond_with_file(request, desired_type, file_path)
- return
-
- logger.debug("We don't have a local thumbnail of that size. Generating")
+ file_info = FileInfo(
+ server_name=None, file_id=media_id,
+ url_cache=media_info["url_cache"],
+ thumbnail=True,
+ thumbnail_width=info["thumbnail_width"],
+ thumbnail_height=info["thumbnail_height"],
+ thumbnail_type=info["thumbnail_type"],
+ thumbnail_method=info["thumbnail_method"],
+ )
+
+ t_type = file_info.thumbnail_type
+ t_length = info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ yield respond_with_responder(request, responder, t_type, t_length)
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
@@ -154,13 +170,14 @@ class ThumbnailResource(Resource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
+ logger.warn("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
desired_width, desired_height,
desired_method, desired_type):
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
+ media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -175,14 +192,24 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
- file_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, desired_width, desired_height,
- desired_type, desired_method,
+ file_info = FileInfo(
+ server_name=server_name, file_id=media_info["filesystem_id"],
+ thumbnail=True,
+ thumbnail_width=info["thumbnail_width"],
+ thumbnail_height=info["thumbnail_height"],
+ thumbnail_type=info["thumbnail_type"],
+ thumbnail_method=info["thumbnail_method"],
)
- yield respond_with_file(request, desired_type, file_path)
- return
- logger.debug("We don't have a local thumbnail of that size. Generating")
+ t_type = file_info.thumbnail_type
+ t_length = info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ yield respond_with_responder(request, responder, t_type, t_length)
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
@@ -193,6 +220,7 @@ class ThumbnailResource(Resource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
+ logger.warn("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
@@ -201,7 +229,7 @@ class ThumbnailResource(Resource):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
- yield self.media_repo.get_remote_media(server_name, media_id)
+ media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -211,18 +239,22 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
- file_id = thumbnail_info["filesystem_id"]
+ file_info = FileInfo(
+ server_name=server_name, file_id=media_info["filesystem_id"],
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ )
+
+ t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
- file_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path, t_length)
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(request, responder, t_type, t_length)
else:
+ logger.info("Failed to find any generated thumbnails")
respond_404(request)
def _select_thumbnail(self, desired_width, desired_height, desired_method,
diff --git a/synapse/state.py b/synapse/state.py
index 9e624b4937..1f9abf9d3d 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -341,7 +341,7 @@ class StateHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events(
+ new_state = yield resolve_events_with_factory(
state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
@@ -404,7 +404,7 @@ class StateHandler(object):
}
with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events(state_set_ids, state_map)
+ new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
@@ -420,19 +420,17 @@ def _ordered_events(events):
return sorted(events, key=key_func)
-def resolve_events(state_sets, state_map_factory):
+def resolve_events_with_state_map(state_sets, state_map):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
- state_map_factory(dict|callable): If callable, then will be called
- with a list of event_ids that are needed, and should return with
- a Deferred of dict of event_id to event. Otherwise, should be
- a dict from event_id to event of all events in state_sets.
+ state_map(dict): a dict from event_id to event, for all events in
+ state_sets.
Returns
- dict[(str, str), synapse.events.FrozenEvent] is a map from
- (type, state_key) to event.
+ dict[(str, str), synapse.events.FrozenEvent]:
+ a map from (type, state_key) to event.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -441,13 +439,6 @@ def resolve_events(state_sets, state_map_factory):
state_sets,
)
- if callable(state_map_factory):
- return _resolve_with_state_fac(
- unconflicted_state, conflicted_state, state_map_factory
- )
-
- state_map = state_map_factory
-
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
@@ -491,8 +482,26 @@ def _seperate(state_sets):
@defer.inlineCallbacks
-def _resolve_with_state_fac(unconflicted_state, conflicted_state,
- state_map_factory):
+def resolve_events_with_factory(state_sets, state_map_factory):
+ """
+ Args:
+ state_sets(list): List of dicts of (type, state_key) -> event_id,
+ which are the different state groups to resolve.
+ state_map_factory(func): will be called
+ with a list of event_ids that are needed, and should return with
+ a Deferred of dict of event_id to event.
+
+ Returns
+ Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
+ a map from (type, state_key) to event.
+ """
+ if len(state_sets) == 1:
+ defer.returnValue(state_sets[0])
+
+ unconflicted_state, conflicted_state = _seperate(
+ state_sets,
+ )
+
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b971f0cb18..68125006eb 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -291,33 +291,33 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
- current_context = LoggingContext.current_context()
+ """Starts a transaction on the database and runs a given function
- start_time = time.time() * 1000
+ Arguments:
+ desc (str): description of the transaction, for logging and metrics
+ func (func): callback function, which will be called with a
+ database transaction (twisted.enterprise.adbapi.Transaction) as
+ its first argument, followed by `args` and `kwargs`.
+
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
+ current_context = LoggingContext.current_context()
after_callbacks = []
final_callbacks = []
def inner_func(conn, *args, **kwargs):
- with LoggingContext("runInteraction") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
-
- if self.database_engine.is_connection_closed(conn):
- logger.debug("Reconnecting closed database connection")
- conn.reconnect()
-
- current_context.copy_to(context)
- return self._new_transaction(
- conn, desc, after_callbacks, final_callbacks, current_context,
- func, *args, **kwargs
- )
+ return self._new_transaction(
+ conn, desc, after_callbacks, final_callbacks, current_context,
+ func, *args, **kwargs
+ )
try:
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
+ result = yield self.runWithConnection(inner_func, *args, **kwargs)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
@@ -329,14 +329,27 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
+ """Wraps the .runWithConnection() method on the underlying db_pool.
+
+ Arguments:
+ func (func): callback function, which will be called with a
+ database connection (twisted.enterprise.adbapi.Connection) as
+ its first argument, followed by `args` and `kwargs`.
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ sched_duration_ms = time.time() * 1000 - start_time
+ sql_scheduling_timer.inc_by(sched_duration_ms)
+ current_context.add_database_scheduled(sched_duration_ms)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ba0da83642..7a9cd3ec90 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -27,7 +27,7 @@ from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
-from synapse.state import resolve_events
+from synapse.state import resolve_events_with_factory
from synapse.util.caches.descriptors import cached
from synapse.types import get_domain_from_id
@@ -146,6 +146,9 @@ class _EventPeristenceQueue(object):
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
+ # handle_queue_loop runs in the sentinel logcontext, so
+ # there is no need to preserve_fn when running the
+ # callbacks on the deferred.
try:
ret = yield per_item_callback(item)
item.deferred.callback(ret)
@@ -157,7 +160,11 @@ class _EventPeristenceQueue(object):
self._event_persist_queues[room_id] = queue
self._currently_persisting_rooms.discard(room_id)
- preserve_fn(handle_queue_loop)()
+ # set handle_queue_loop off on the background. We don't want to
+ # attribute work done in it to the current request, so we drop the
+ # logcontext altogether.
+ with PreserveLoggingContext():
+ handle_queue_loop()
def _get_drainining_queue(self, room_id):
queue = self._event_persist_queues.setdefault(room_id, deque())
@@ -556,7 +563,7 @@ class EventsStore(SQLBaseStore):
to_return.update(evs)
defer.returnValue(to_return)
- current_state = yield resolve_events(
+ current_state = yield resolve_events_with_factory(
state_sets,
state_map_factory=get_events,
)
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 6ebc372498..e6cdbb0545 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -173,7 +173,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+ def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+ """Updates the last access time of the given media
+
+ Args:
+ local_media (iterable[str]): Set of media_ids
+ remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+ time_ms: Current time in milliseconds
+ """
def update_cache_txn(txn):
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
@@ -181,8 +188,18 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
txn.executemany(sql, (
- (time_ts, media_origin, media_id)
- for media_origin, media_id in origin_id_tuples
+ (time_ms, media_origin, media_id)
+ for media_origin, media_id in remote_media
+ ))
+
+ sql = (
+ "UPDATE local_media_repository SET last_access_ts = ?"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, (
+ (time_ms, media_id)
+ for media_id in local_media
))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index d1691bbac2..c845a0cec5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 46
+SCHEMA_VERSION = 47
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql
new file mode 100644
index 0000000000..f505fb22b5
--- /dev/null
+++ b/synapse/storage/schema/delta/47/last_access_media.sql
@@ -0,0 +1,16 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index c9bff408ef..f150ef0103 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -641,8 +641,13 @@ class UserDirectoryStore(SQLBaseStore):
"""
if self.hs.config.user_directory_search_all_users:
- join_clause = ""
- where_clause = "?<>''" # naughty hack to keep the same number of binds
+ # dummy to keep the number of binds & aliases the same
+ join_clause = """
+ LEFT JOIN (
+ SELECT NULL as user_id WHERE NULL = ?
+            ) AS s USING (user_id)
+ """
+ where_clause = ""
else:
join_clause = """
LEFT JOIN users_in_public_rooms AS p USING (user_id)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 48c9f6802d..94fa7cac98 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -52,13 +52,17 @@ except Exception:
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block.
+
Args:
name (str): Name for the context for debugging.
"""
__slots__ = [
- "previous_context", "name", "usage_start", "usage_end", "main_thread",
- "__dict__", "tag", "alive",
+ "previous_context", "name", "ru_stime", "ru_utime",
+ "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+ "usage_start", "usage_end",
+ "main_thread", "alive",
+ "request", "tag",
]
thread_local = threading.local()
@@ -83,6 +87,9 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
pass
+ def add_database_scheduled(self, sched_ms):
+ pass
+
def __nonzero__(self):
return False
@@ -94,9 +101,17 @@ class LoggingContext(object):
self.ru_stime = 0.
self.ru_utime = 0.
self.db_txn_count = 0
- self.db_txn_duration = 0.
+
+ # ms spent waiting for db txns, excluding scheduling time
+ self.db_txn_duration_ms = 0
+
+ # ms spent waiting for db txns to be scheduled
+ self.db_sched_duration_ms = 0
+
self.usage_start = None
+ self.usage_end = None
self.main_thread = threading.current_thread()
+ self.request = None
self.tag = ""
self.alive = True
@@ -105,7 +120,11 @@ class LoggingContext(object):
@classmethod
def current_context(cls):
- """Get the current logging context from thread local storage"""
+ """Get the current logging context from thread local storage
+
+ Returns:
+ LoggingContext: the current logging context
+ """
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
@@ -155,11 +174,13 @@ class LoggingContext(object):
self.alive = False
def copy_to(self, record):
- """Copy fields from this context to the record"""
- for key, value in self.__dict__.items():
- setattr(record, key, value)
+ """Copy logging fields from this context to a log record or
+ another LoggingContext
+ """
- record.ru_utime, record.ru_stime = self.get_resource_usage()
+ # 'request' is the only field we currently use in the logger, so that's
+ # all we need to copy
+ record.request = self.request
def start(self):
if threading.current_thread() is not self.main_thread:
@@ -194,7 +215,16 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
self.db_txn_count += 1
- self.db_txn_duration += duration_ms / 1000.
+ self.db_txn_duration_ms += duration_ms
+
+ def add_database_scheduled(self, sched_ms):
+ """Record a use of the database pool
+
+ Args:
+ sched_ms (int): number of milliseconds it took us to get a
+ connection
+ """
+ self.db_sched_duration_ms += sched_ms
class LoggingContextFilter(logging.Filter):
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4ea930d3e8..059bb7fedf 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
-block_timer = metrics.register_distribution(
- "block_timer",
- labels=["block_name"]
+# total number of times we have hit this block
+response_count = metrics.register_counter(
+ "block_count",
+ labels=["block_name"],
+ alternative_names=(
+ # the following are all deprecated aliases for the same metric
+ metrics.name_prefix + x for x in (
+ "_block_timer:count",
+ "_block_ru_utime:count",
+ "_block_ru_stime:count",
+ "_block_db_txn_count:count",
+ "_block_db_txn_duration:count",
+ )
+ )
+)
+
+block_timer = metrics.register_counter(
+ "block_time_seconds",
+ labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_timer:total",
+ ),
)
-block_ru_utime = metrics.register_distribution(
- "block_ru_utime", labels=["block_name"]
+block_ru_utime = metrics.register_counter(
+ "block_ru_utime_seconds", labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_ru_utime:total",
+ ),
)
-block_ru_stime = metrics.register_distribution(
- "block_ru_stime", labels=["block_name"]
+block_ru_stime = metrics.register_counter(
+ "block_ru_stime_seconds", labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_ru_stime:total",
+ ),
)
-block_db_txn_count = metrics.register_distribution(
- "block_db_txn_count", labels=["block_name"]
+block_db_txn_count = metrics.register_counter(
+ "block_db_txn_count", labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_db_txn_count:total",
+ ),
)
-block_db_txn_duration = metrics.register_distribution(
- "block_db_txn_duration", labels=["block_name"]
+# seconds spent waiting for db txns, excluding scheduling time, in this block
+block_db_txn_duration = metrics.register_counter(
+ "block_db_txn_duration_seconds", labels=["block_name"],
+ alternative_names=(
+        metrics.name_prefix + "_block_db_txn_duration:total",
+ ),
+)
+
+# seconds spent waiting for a db connection, in this block
+block_db_sched_duration = metrics.register_counter(
+ "block_db_sched_duration_seconds", labels=["block_name"],
)
@@ -64,7 +101,9 @@ def measure_func(name):
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
- "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
+ "ru_stime",
+ "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+ "created_context",
]
def __init__(self, clock, name):
@@ -84,7 +123,8 @@ class Measure(object):
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
- self.db_txn_duration = self.start_context.db_txn_duration
+ self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
+ self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(exc_type, Exception) or not self.start_context:
@@ -114,7 +154,12 @@ class Measure(object):
context.db_txn_count - self.db_txn_count, self.name
)
block_db_txn_duration.inc_by(
- context.db_txn_duration - self.db_txn_duration, self.name
+ (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
+ self.name
+ )
+ block_db_sched_duration.inc_by(
+ (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
+ self.name
)
if self.created_context:
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 570312da84..c899fecf5d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase):
def check_context(self, _, expected):
self.assertEquals(
- getattr(LoggingContext.current_context(), "test_key", None),
+ getattr(LoggingContext.current_context(), "request", None),
expected
)
@@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase):
lookup_2_deferred = defer.Deferred()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
wait_1_deferred = kr.wait_for_previous_lookups(
["server1"],
@@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase):
wait_1_deferred.addBoth(self.check_context, "one")
with LoggingContext("two") as context_two:
- context_two.test_key = "two"
+ context_two.request = "two"
# set off another wait. It should block because the first lookup
# hasn't yet completed.
@@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
self.assertEquals(
- LoggingContext.current_context().test_key, "11",
+ LoggingContext.current_context().request, "11",
)
with logcontext.PreserveLoggingContext():
yield persp_deferred
@@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase):
self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11:
- context_11.test_key = "11"
+ context_11.request = "11"
# start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server(
@@ -173,7 +173,7 @@ class KeyringTestCase(unittest.TestCase):
self.assertIs(LoggingContext.current_context(), context_11)
context_12 = LoggingContext("12")
- context_12.test_key = "12"
+ context_12.request = "12"
with logcontext.PreserveLoggingContext(context_12):
# a second request for a server with outstanding requests
# should block rather than start a second call
@@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
defer = kr.verify_json_for_server("server9", {})
try:
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index e2f7765f49..4850722bc5 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
self.assertEquals(
- LoggingContext.current_context().test_key, value
+ LoggingContext.current_context().request, value
)
def test_with_context(self):
with LoggingContext() as context_one:
- context_one.test_key = "test"
+ context_one.request = "test"
self._check_test_key("test")
@defer.inlineCallbacks
@@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def competing_callback():
with LoggingContext() as competing_context:
- competing_context.test_key = "competing"
+ competing_context.request = "competing"
yield sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
yield sleep(0)
self._check_test_key("one")
@@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def cb():
- context_one.test_key = "one"
+ context_one.request = "one"
yield function()
self._check_test_key("one")
callback_completed[0] = True
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
# fire off function, but don't wait on it.
logcontext.preserve_fn(cb)()
@@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
@@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
argument isn't actually a deferred"""
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable("bum")
self._check_test_key("one")
|