diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index f7fad15c62..d996aa90bb 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -17,7 +17,8 @@ import logging
from twisted.internet import defer
-from synapse.types import get_domain_from_id
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id, UserID
from synapse.util.stringutils import random_string
@@ -33,7 +34,7 @@ class DeviceMessageHandler(object):
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler(
@@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
+ logger.warning("Request for keys for non-local user %s",
+ user_id)
+ raise SynapseError(400, "Not a user here")
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
messages_by_device = {
device_id: {
"content": message_content,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 668a90e495..5af8abf66b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -20,7 +20,7 @@ from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException
-from synapse.types import get_domain_from_id
+from synapse.types import get_domain_from_id, UserID
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
@@ -32,7 +32,7 @@ class E2eKeysHandler(object):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.device_handler = hs.get_device_handler()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the
@@ -70,7 +70,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
remote_queries[user_id] = device_ids
@@ -170,7 +171,8 @@ class E2eKeysHandler(object):
result_dict = {}
for user_id, device_ids in query.items():
- if not self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
@@ -213,7 +215,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 269b65ca41..165c684d0d 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -93,6 +93,8 @@ response_db_txn_count = metrics.register_counter(
),
)
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
response_db_txn_duration = metrics.register_counter(
"response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
alternative_names=(
@@ -100,6 +102,10 @@ response_db_txn_duration = metrics.register_counter(
),
)
+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = metrics.register_counter(
+ "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
+)
_next_request_id = 0
@@ -316,15 +322,6 @@ class JsonResource(HttpServer, resource.Resource):
def _send_response(self, request, code, response_json_object,
response_code_message=None):
- # could alternatively use request.notifyFinish() and flip a flag when
- # the Deferred fires, but since the flag is RIGHT THERE it seems like
- # a waste.
- if request._disconnected:
- logger.warn(
- "Not sending response to request %s, already disconnected.",
- request)
- return
-
outgoing_responses_counter.inc(request.method, str(code))
# TODO: Only enable CORS for the requests that need it.
@@ -377,7 +374,10 @@ class RequestMetrics(object):
context.db_txn_count, request.method, self.name, tag
)
response_db_txn_duration.inc_by(
- context.db_txn_duration, request.method, self.name, tag
+ context.db_txn_duration_ms / 1000., request.method, self.name, tag
+ )
+ response_db_sched_duration.inc_by(
+ context.db_sched_duration_ms / 1000., request.method, self.name, tag
)
@@ -400,6 +400,15 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
version_string="", canonical_json=True):
+ # could alternatively use request.notifyFinish() and flip a flag when
+ # the Deferred fires, but since the flag is RIGHT THERE it seems like
+ # a waste.
+ if request._disconnected:
+ logger.warn(
+ "Not sending response to request %s, already disconnected.",
+ request)
+ return
+
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index cd1492b1c3..e422c8dfae 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -66,14 +66,15 @@ class SynapseRequest(Request):
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
- db_txn_duration = context.db_txn_duration
+ db_txn_duration_ms = context.db_txn_duration_ms
+ db_sched_duration_ms = context.db_sched_duration_ms
except Exception:
ru_utime, ru_stime = (0, 0)
- db_txn_count, db_txn_duration = (0, 0)
+            db_txn_count, db_txn_duration_ms, db_sched_duration_ms = (0, 0, 0)
self.site.access_logger.info(
"%s - %s - {%s}"
- " Processed request: %dms (%dms, %dms) (%dms/%d)"
+ " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
@@ -81,7 +82,8 @@ class SynapseRequest(Request):
int(time.time() * 1000) - self.start_time,
int(ru_utime * 1000),
int(ru_stime * 1000),
- int(db_txn_duration * 1000),
+ db_sched_duration_ms,
+ db_txn_duration_ms,
int(db_txn_count),
self.sentLength,
self.code,
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index f480aae614..1e783e5ff4 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -15,6 +15,9 @@
from itertools import chain
+import logging
+
+logger = logging.getLogger(__name__)
def flatten(items):
@@ -153,7 +156,11 @@ class CallbackMetric(BaseMetric):
self.callback = callback
def render(self):
- value = self.callback()
+ try:
+ value = self.callback()
+ except Exception:
+ logger.exception("Failed to render %s", self.name)
+ return ["# FAILED to render " + self.name]
if self.is_scalar():
return list(self._render_for_labels([], value))
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 4163c9e416..4f56bcf577 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -51,7 +51,7 @@ import urlparse
logger = logging.getLogger(__name__)
-UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object):
@@ -73,6 +73,7 @@ class MediaRepository(object):
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
+ self.recently_accessed_locals = set()
# List of StorageProviders where we should search for media and
# potentially upload to.
@@ -93,19 +94,34 @@ class MediaRepository(object):
)
self.clock.looping_call(
- self._update_recently_accessed_remotes,
- UPDATE_RECENTLY_ACCESSED_REMOTES_TS
+ self._update_recently_accessed,
+ UPDATE_RECENTLY_ACCESSED_TS,
)
@defer.inlineCallbacks
- def _update_recently_accessed_remotes(self):
- media = self.recently_accessed_remotes
+ def _update_recently_accessed(self):
+ remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
+ local_media = self.recently_accessed_locals
+ self.recently_accessed_locals = set()
+
yield self.store.update_cached_last_access_time(
- media, self.clock.time_msec()
+ local_media, remote_media, self.clock.time_msec()
)
+ def mark_recently_accessed(self, server_name, media_id):
+ """Mark the given media as recently accessed.
+
+ Args:
+ server_name (str|None): Origin server of media, or None if local
+ media_id (str): The media ID of the content
+ """
+ if server_name:
+ self.recently_accessed_remotes.add((server_name, media_id))
+ else:
+ self.recently_accessed_locals.add(media_id)
+
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
auth_user):
@@ -167,6 +183,8 @@ class MediaRepository(object):
respond_404(request)
return
+ self.mark_recently_accessed(None, media_id)
+
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
@@ -198,7 +216,7 @@ class MediaRepository(object):
Deferred: Resolves once a response has successfully been written
to request
"""
- self.recently_accessed_remotes.add((server_name, media_id))
+ self.mark_recently_accessed(server_name, media_id)
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 8c9653843b..12e84a2b7c 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -67,6 +67,7 @@ class ThumbnailResource(Resource):
yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(None, media_id)
else:
if self.dynamic_thumbnails:
yield self._select_or_generate_remote_thumbnail(
@@ -78,6 +79,7 @@ class ThumbnailResource(Resource):
request, server_name, media_id,
width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height,
diff --git a/synapse/state.py b/synapse/state.py
index 9e624b4937..1f9abf9d3d 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -341,7 +341,7 @@ class StateHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events(
+ new_state = yield resolve_events_with_factory(
state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
@@ -404,7 +404,7 @@ class StateHandler(object):
}
with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events(state_set_ids, state_map)
+ new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
@@ -420,19 +420,17 @@ def _ordered_events(events):
return sorted(events, key=key_func)
-def resolve_events(state_sets, state_map_factory):
+def resolve_events_with_state_map(state_sets, state_map):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
- state_map_factory(dict|callable): If callable, then will be called
- with a list of event_ids that are needed, and should return with
- a Deferred of dict of event_id to event. Otherwise, should be
- a dict from event_id to event of all events in state_sets.
+ state_map(dict): a dict from event_id to event, for all events in
+ state_sets.
Returns
- dict[(str, str), synapse.events.FrozenEvent] is a map from
- (type, state_key) to event.
+ dict[(str, str), synapse.events.FrozenEvent]:
+ a map from (type, state_key) to event.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -441,13 +439,6 @@ def resolve_events(state_sets, state_map_factory):
state_sets,
)
- if callable(state_map_factory):
- return _resolve_with_state_fac(
- unconflicted_state, conflicted_state, state_map_factory
- )
-
- state_map = state_map_factory
-
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
@@ -491,8 +482,26 @@ def _seperate(state_sets):
@defer.inlineCallbacks
-def _resolve_with_state_fac(unconflicted_state, conflicted_state,
- state_map_factory):
+def resolve_events_with_factory(state_sets, state_map_factory):
+ """
+ Args:
+ state_sets(list): List of dicts of (type, state_key) -> event_id,
+ which are the different state groups to resolve.
+        state_map_factory(func): will be called with a list of event_ids
+            that are needed, and should return a Deferred of a dict of
+            event_id to event.
+
+ Returns
+ Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
+ a map from (type, state_key) to event.
+ """
+ if len(state_sets) == 1:
+ defer.returnValue(state_sets[0])
+
+ unconflicted_state, conflicted_state = _seperate(
+ state_sets,
+ )
+
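+    # collect the ids of the events needed to resolve the conflicted state;
+    # the state_map_factory is then used to fetch them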
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b971f0cb18..68125006eb 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -291,33 +291,33 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
- current_context = LoggingContext.current_context()
+ """Starts a transaction on the database and runs a given function
- start_time = time.time() * 1000
+ Arguments:
+ desc (str): description of the transaction, for logging and metrics
+ func (func): callback function, which will be called with a
+ database transaction (twisted.enterprise.adbapi.Transaction) as
+ its first argument, followed by `args` and `kwargs`.
+
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
+ current_context = LoggingContext.current_context()
after_callbacks = []
final_callbacks = []
def inner_func(conn, *args, **kwargs):
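+            # runWithConnection (below) takes care of the logcontext and of
+            # tracking the scheduling time, so all we need to do here is start
+            # the transaction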
- with LoggingContext("runInteraction") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
-
- if self.database_engine.is_connection_closed(conn):
- logger.debug("Reconnecting closed database connection")
- conn.reconnect()
-
- current_context.copy_to(context)
- return self._new_transaction(
- conn, desc, after_callbacks, final_callbacks, current_context,
- func, *args, **kwargs
- )
+ return self._new_transaction(
+ conn, desc, after_callbacks, final_callbacks, current_context,
+ func, *args, **kwargs
+ )
try:
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
+ result = yield self.runWithConnection(inner_func, *args, **kwargs)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
@@ -329,14 +329,27 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
+ """Wraps the .runWithConnection() method on the underlying db_pool.
+
+        Args:
+ func (func): callback function, which will be called with a
+ database connection (twisted.enterprise.adbapi.Connection) as
+ its first argument, followed by `args` and `kwargs`.
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ sched_duration_ms = time.time() * 1000 - start_time
+ sql_scheduling_timer.inc_by(sched_duration_ms)
+ current_context.add_database_scheduled(sched_duration_ms)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ad1d782705..7a9cd3ec90 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -27,7 +27,7 @@ from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
-from synapse.state import resolve_events
+from synapse.state import resolve_events_with_factory
from synapse.util.caches.descriptors import cached
from synapse.types import get_domain_from_id
@@ -535,6 +535,12 @@ class EventsStore(SQLBaseStore):
# the events we have yet to persist, so we need a slightly more
# complicated event lookup function than simply looking the events
# up in the db.
+
+ logger.info(
+ "Resolving state for %s with %i state sets",
+ room_id, len(state_sets),
+ )
+
events_map = {ev.event_id: ev for ev, _ in events_context}
@defer.inlineCallbacks
@@ -557,7 +563,7 @@ class EventsStore(SQLBaseStore):
to_return.update(evs)
defer.returnValue(to_return)
- current_state = yield resolve_events(
+ current_state = yield resolve_events_with_factory(
state_sets,
state_map_factory=get_events,
)
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 6ebc372498..e6cdbb0545 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -173,7 +173,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+ def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+ """Updates the last access time of the given media
+
+ Args:
+ local_media (iterable[str]): Set of media_ids
+ remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+            time_ms (int): Current time in milliseconds
+ """
def update_cache_txn(txn):
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
@@ -181,8 +188,18 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
txn.executemany(sql, (
- (time_ts, media_origin, media_id)
- for media_origin, media_id in origin_id_tuples
+ (time_ms, media_origin, media_id)
+ for media_origin, media_id in remote_media
+ ))
+
+ sql = (
+ "UPDATE local_media_repository SET last_access_ts = ?"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, (
+ (time_ms, media_id)
+ for media_id in local_media
))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index d1691bbac2..c845a0cec5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 46
+SCHEMA_VERSION = 47
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql
new file mode 100644
index 0000000000..f505fb22b5
--- /dev/null
+++ b/synapse/storage/schema/delta/47/last_access_media.sql
@@ -0,0 +1,16 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
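+-- ms since the unix epoch at which this media was last accessed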
+ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index c9bff408ef..f150ef0103 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -641,8 +641,13 @@ class UserDirectoryStore(SQLBaseStore):
"""
if self.hs.config.user_directory_search_all_users:
- join_clause = ""
- where_clause = "?<>''" # naughty hack to keep the same number of binds
+ # dummy to keep the number of binds & aliases the same
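+            # (NULL = ? can never be true, so the sub-select is empty and the
+            # LEFT JOIN adds nothing; it only keeps the `s` alias and the
+            # extra bind parameter in place)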
+ join_clause = """
+ LEFT JOIN (
+ SELECT NULL as user_id WHERE NULL = ?
+            ) AS s USING (user_id)
+ """
+ where_clause = ""
else:
join_clause = """
LEFT JOIN users_in_public_rooms AS p USING (user_id)
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
new file mode 100644
index 0000000000..90a2608d6f
--- /dev/null
+++ b/synapse/util/file_consumer.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import threads, reactor
+
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+
+import Queue
+
+
+class BackgroundFileConsumer(object):
+ """A consumer that writes to a file like object. Supports both push
+ and pull producers
+
+ Args:
+        file_obj (file): The file-like object to write to. Closed when
+ finished.
+ """
+
+ # For PushProducers pause if we have this many unwritten slices
+ _PAUSE_ON_QUEUE_SIZE = 5
+ # And resume once the size of the queue is less than this
+ _RESUME_ON_QUEUE_SIZE = 2
+
+ def __init__(self, file_obj):
+ self._file_obj = file_obj
+
+ # Producer we're registered with
+ self._producer = None
+
+ # True if PushProducer, false if PullProducer
+ self.streaming = False
+
+ # For PushProducers, indicates whether we've paused the producer and
+ # need to call resumeProducing before we get more data.
+ self._paused_producer = False
+
+        # Queue of slices of bytes to be written. When the producer calls
+        # unregisterProducer, a final None is sent as a sentinel.
+ self._bytes_queue = Queue.Queue()
+
+ # Deferred that is resolved when finished writing
+ self._finished_deferred = None
+
+ # If the _writer thread throws an exception it gets stored here.
+ self._write_exception = None
+
+ def registerProducer(self, producer, streaming):
+ """Part of IConsumer interface
+
+ Args:
+ producer (IProducer)
+ streaming (bool): True if push based producer, False if pull
+ based.
+ """
+ if self._producer:
+ raise Exception("registerProducer called twice")
+
+ self._producer = producer
+ self.streaming = streaming
+ self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
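+        # push producers start calling write() of their own accord; pull
+        # producers have to be asked for their first chunk (and, in _writer,
+        # for each subsequent one)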
+ if not streaming:
+ self._producer.resumeProducing()
+
+ def unregisterProducer(self):
+ """Part of IProducer interface
+ """
+ self._producer = None
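+        # if the writer thread is still going, push a sentinel so that it
+        # wakes up, sees that the producer has gone and finishes up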
+ if not self._finished_deferred.called:
+ self._bytes_queue.put_nowait(None)
+
+ def write(self, bytes):
+ """Part of IProducer interface
+ """
+ if self._write_exception:
+ raise self._write_exception
+
+ if self._finished_deferred.called:
+ raise Exception("consumer has closed")
+
+ self._bytes_queue.put_nowait(bytes)
+
+ # If this is a PushProducer and the queue is getting behind
+ # then we pause the producer.
+ if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
+ self._paused_producer = True
+ self._producer.pauseProducing()
+
+ def _writer(self):
+ """This is run in a background thread to write to the file.
+ """
+ try:
+ while self._producer or not self._bytes_queue.empty():
+ # If we've paused the producer check if we should resume the
+ # producer.
+ if self._producer and self._paused_producer:
+ if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
+ reactor.callFromThread(self._resume_paused_producer)
+
+ bytes = self._bytes_queue.get()
+
+                # A None (or empty string) is a signal to loop round and
+                # re-check whether we should stop.
+ if bytes:
+ self._file_obj.write(bytes)
+
+                # If it's a pull producer then we need to explicitly ask it
+                # for more data.
+ if not self.streaming and self._producer:
+ reactor.callFromThread(self._producer.resumeProducing)
+ except Exception as e:
+ self._write_exception = e
+ raise
+ finally:
+ self._file_obj.close()
+
+ def wait(self):
+ """Returns a deferred that resolves when finished writing to file
+ """
+ return make_deferred_yieldable(self._finished_deferred)
+
+ def _resume_paused_producer(self):
+ """Gets called if we should resume producing after being paused
+ """
+ if self._paused_producer and self._producer:
+ self._paused_producer = False
+ self._producer.resumeProducing()
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 48c9f6802d..94fa7cac98 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -52,13 +52,17 @@ except Exception:
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block.
+
Args:
name (str): Name for the context for debugging.
"""
__slots__ = [
- "previous_context", "name", "usage_start", "usage_end", "main_thread",
- "__dict__", "tag", "alive",
+ "previous_context", "name", "ru_stime", "ru_utime",
+ "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+ "usage_start", "usage_end",
+ "main_thread", "alive",
+ "request", "tag",
]
thread_local = threading.local()
@@ -83,6 +87,9 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
pass
+ def add_database_scheduled(self, sched_ms):
+ pass
+
def __nonzero__(self):
return False
@@ -94,9 +101,17 @@ class LoggingContext(object):
self.ru_stime = 0.
self.ru_utime = 0.
self.db_txn_count = 0
- self.db_txn_duration = 0.
+
+ # ms spent waiting for db txns, excluding scheduling time
+ self.db_txn_duration_ms = 0
+
+ # ms spent waiting for db txns to be scheduled
+ self.db_sched_duration_ms = 0
+
self.usage_start = None
+ self.usage_end = None
self.main_thread = threading.current_thread()
+ self.request = None
self.tag = ""
self.alive = True
@@ -105,7 +120,11 @@ class LoggingContext(object):
@classmethod
def current_context(cls):
- """Get the current logging context from thread local storage"""
+ """Get the current logging context from thread local storage
+
+ Returns:
+ LoggingContext: the current logging context
+ """
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
@@ -155,11 +174,13 @@ class LoggingContext(object):
self.alive = False
def copy_to(self, record):
- """Copy fields from this context to the record"""
- for key, value in self.__dict__.items():
- setattr(record, key, value)
+ """Copy logging fields from this context to a log record or
+ another LoggingContext
+ """
- record.ru_utime, record.ru_stime = self.get_resource_usage()
+ # 'request' is the only field we currently use in the logger, so that's
+ # all we need to copy
+ record.request = self.request
def start(self):
if threading.current_thread() is not self.main_thread:
@@ -194,7 +215,16 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
self.db_txn_count += 1
- self.db_txn_duration += duration_ms / 1000.
+ self.db_txn_duration_ms += duration_ms
+
+ def add_database_scheduled(self, sched_ms):
+ """Record a use of the database pool
+
+ Args:
+ sched_ms (int): number of milliseconds it took us to get a
+ connection
+ """
+ self.db_sched_duration_ms += sched_ms
class LoggingContextFilter(logging.Filter):
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 8d22ff3068..e4b5687a4b 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
# total number of times we have hit this block
-response_count = metrics.register_counter(
+block_counter = metrics.register_counter(
"block_count",
labels=["block_name"],
alternative_names=(
@@ -72,13 +72,19 @@ block_db_txn_count = metrics.register_counter(
),
)
+# seconds spent waiting for db txns, excluding scheduling time, in this block
block_db_txn_duration = metrics.register_counter(
"block_db_txn_duration_seconds", labels=["block_name"],
alternative_names=(
- metrics.name_prefix + "_block_db_txn_count:total",
+ metrics.name_prefix + "_block_db_txn_duration:total",
),
)
+# seconds spent waiting for a db connection, in this block
+block_db_sched_duration = metrics.register_counter(
+ "block_db_sched_duration_seconds", labels=["block_name"],
+)
+
def measure_func(name):
def wrapper(func):
@@ -95,7 +101,9 @@ def measure_func(name):
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
- "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
+ "ru_stime",
+ "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+ "created_context",
]
def __init__(self, clock, name):
@@ -115,13 +123,16 @@ class Measure(object):
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
- self.db_txn_duration = self.start_context.db_txn_duration
+ self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
+ self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(exc_type, Exception) or not self.start_context:
return
duration = self.clock.time_msec() - self.start
+
+ block_counter.inc(self.name)
block_timer.inc_by(duration, self.name)
context = LoggingContext.current_context()
@@ -145,7 +156,12 @@ class Measure(object):
context.db_txn_count - self.db_txn_count, self.name
)
block_db_txn_duration.inc_by(
- context.db_txn_duration - self.db_txn_duration, self.name
+ (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
+ self.name
+ )
+ block_db_sched_duration.inc_by(
+ (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
+ self.name
)
if self.created_context:
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 570312da84..c899fecf5d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase):
def check_context(self, _, expected):
self.assertEquals(
- getattr(LoggingContext.current_context(), "test_key", None),
+ getattr(LoggingContext.current_context(), "request", None),
expected
)
@@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase):
lookup_2_deferred = defer.Deferred()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
wait_1_deferred = kr.wait_for_previous_lookups(
["server1"],
@@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase):
wait_1_deferred.addBoth(self.check_context, "one")
with LoggingContext("two") as context_two:
- context_two.test_key = "two"
+ context_two.request = "two"
# set off another wait. It should block because the first lookup
# hasn't yet completed.
@@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
self.assertEquals(
- LoggingContext.current_context().test_key, "11",
+ LoggingContext.current_context().request, "11",
)
with logcontext.PreserveLoggingContext():
yield persp_deferred
@@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase):
self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11:
- context_11.test_key = "11"
+ context_11.request = "11"
# start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server(
@@ -173,7 +173,7 @@ class KeyringTestCase(unittest.TestCase):
self.assertIs(LoggingContext.current_context(), context_11)
context_12 = LoggingContext("12")
- context_12.test_key = "12"
+ context_12.request = "12"
with logcontext.PreserveLoggingContext(context_12):
# a second request for a server with outstanding requests
# should block rather than start a second call
@@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
defer = kr.verify_json_for_server("server9", {})
try:
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
new file mode 100644
index 0000000000..76e2234255
--- /dev/null
+++ b/tests/util/test_file_consumer.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer, reactor
+from mock import NonCallableMock
+
+from synapse.util.file_consumer import BackgroundFileConsumer
+
+from tests import unittest
+from StringIO import StringIO
+
+import threading
+
+
+class FileConsumerTests(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def test_pull_consumer(self):
+ string_file = StringIO()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = DummyPullProducer()
+
+ yield producer.register_with_consumer(consumer)
+
+ yield producer.write_and_wait("Foo")
+
+ self.assertEqual(string_file.getvalue(), "Foo")
+
+ yield producer.write_and_wait("Bar")
+
+ self.assertEqual(string_file.getvalue(), "FooBar")
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+ @defer.inlineCallbacks
+ def test_push_consumer(self):
+ string_file = BlockingStringWrite()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = NonCallableMock(spec_set=[])
+
+ consumer.registerProducer(producer, True)
+
+ consumer.write("Foo")
+ yield string_file.wait_for_n_writes(1)
+
+ self.assertEqual(string_file.buffer, "Foo")
+
+ consumer.write("Bar")
+ yield string_file.wait_for_n_writes(2)
+
+ self.assertEqual(string_file.buffer, "FooBar")
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+ @defer.inlineCallbacks
+ def test_push_producer_feedback(self):
+ string_file = BlockingStringWrite()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
+
+ resume_deferred = defer.Deferred()
+ producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None)
+
+ consumer.registerProducer(producer, True)
+
+ number_writes = 0
+ with string_file.write_lock:
+ for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
+ consumer.write("Foo")
+ number_writes += 1
+
+ producer.pauseProducing.assert_called_once()
+
+ yield string_file.wait_for_n_writes(number_writes)
+
+ yield resume_deferred
+ producer.resumeProducing.assert_called_once()
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+
+class DummyPullProducer(object):
+ def __init__(self):
+ self.consumer = None
+ self.deferred = defer.Deferred()
+
+ def resumeProducing(self):
+ d = self.deferred
+ self.deferred = defer.Deferred()
+ d.callback(None)
+
+ def write_and_wait(self, bytes):
+ d = self.deferred
+ self.consumer.write(bytes)
+ return d
+
+ def register_with_consumer(self, consumer):
+ d = self.deferred
+ self.consumer = consumer
+ self.consumer.registerProducer(self, False)
+ return d
+
+
+class BlockingStringWrite(object):
+ def __init__(self):
+ self.buffer = ""
+ self.closed = False
+ self.write_lock = threading.Lock()
+
+ self._notify_write_deferred = None
+ self._number_of_writes = 0
+
+ def write(self, bytes):
+ with self.write_lock:
+ self.buffer += bytes
+ self._number_of_writes += 1
+
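+        # this is called from the consumer's background writer thread, so make
+        # sure the deferred gets fired back on the reactor thread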
+ reactor.callFromThread(self._notify_write)
+
+ def close(self):
+ self.closed = True
+
+ def _notify_write(self):
+ "Called by write to indicate a write happened"
+ with self.write_lock:
+ if not self._notify_write_deferred:
+ return
+ d = self._notify_write_deferred
+ self._notify_write_deferred = None
+ d.callback(None)
+
+ @defer.inlineCallbacks
+ def wait_for_n_writes(self, n):
+ "Wait for n writes to have happened"
+ while True:
+ with self.write_lock:
+ if n <= self._number_of_writes:
+ return
+
+ if not self._notify_write_deferred:
+ self._notify_write_deferred = defer.Deferred()
+
+ d = self._notify_write_deferred
+
+ yield d
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index e2f7765f49..4850722bc5 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
self.assertEquals(
- LoggingContext.current_context().test_key, value
+ LoggingContext.current_context().request, value
)
def test_with_context(self):
with LoggingContext() as context_one:
- context_one.test_key = "test"
+ context_one.request = "test"
self._check_test_key("test")
@defer.inlineCallbacks
@@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def competing_callback():
with LoggingContext() as competing_context:
- competing_context.test_key = "competing"
+ competing_context.request = "competing"
yield sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
yield sleep(0)
self._check_test_key("one")
@@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def cb():
- context_one.test_key = "one"
+ context_one.request = "one"
yield function()
self._check_test_key("one")
callback_completed[0] = True
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
# fire off function, but don't wait on it.
logcontext.preserve_fn(cb)()
@@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
@@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
argument isn't actually a deferred"""
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable("bum")
self._check_test_key("one")
|