diff options
-rw-r--r-- | synapse/handlers/devicemessage.py | 14 | ||||
-rw-r--r-- | synapse/handlers/e2e_keys.py | 13 | ||||
-rw-r--r-- | synapse/http/server.py | 29 | ||||
-rw-r--r-- | synapse/http/site.py | 10 | ||||
-rw-r--r-- | synapse/state.py | 45 | ||||
-rw-r--r-- | synapse/storage/_base.py | 55 | ||||
-rw-r--r-- | synapse/storage/events.py | 4 | ||||
-rw-r--r-- | synapse/storage/user_directory.py | 9 | ||||
-rw-r--r-- | synapse/util/logcontext.py | 48 | ||||
-rw-r--r-- | synapse/util/metrics.py | 20 | ||||
-rw-r--r-- | tests/crypto/test_keyring.py | 14 | ||||
-rw-r--r-- | tests/util/test_logcontext.py | 16 |
12 files changed, 185 insertions, 92 deletions
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index f7fad15c62..d996aa90bb 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -17,7 +17,8 @@ import logging from twisted.internet import defer -from synapse.types import get_domain_from_id +from synapse.api.errors import SynapseError +from synapse.types import get_domain_from_id, UserID from synapse.util.stringutils import random_string @@ -33,7 +34,7 @@ class DeviceMessageHandler(object): """ self.store = hs.get_datastore() self.notifier = hs.get_notifier() - self.is_mine_id = hs.is_mine_id + self.is_mine = hs.is_mine self.federation = hs.get_federation_sender() hs.get_replication_layer().register_edu_handler( @@ -52,6 +53,12 @@ class DeviceMessageHandler(object): message_type = content["type"] message_id = content["message_id"] for user_id, by_device in content["messages"].items(): + # we use UserID.from_string to catch invalid user ids + if not self.is_mine(UserID.from_string(user_id)): + logger.warning("Request for keys for non-local user %s", + user_id) + raise SynapseError(400, "Not a user here") + messages_by_device = { device_id: { "content": message_content, @@ -77,7 +84,8 @@ class DeviceMessageHandler(object): local_messages = {} remote_messages = {} for user_id, by_device in messages.items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): messages_by_device = { device_id: { "content": message_content, diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 668a90e495..5af8abf66b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -20,7 +20,7 @@ from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.api.errors import SynapseError, CodeMessageException -from synapse.types import get_domain_from_id +from synapse.types import get_domain_from_id, UserID from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination @@ -32,7 +32,7 @@ class E2eKeysHandler(object): self.store = hs.get_datastore() self.federation = hs.get_replication_layer() self.device_handler = hs.get_device_handler() - self.is_mine_id = hs.is_mine_id + self.is_mine = hs.is_mine self.clock = hs.get_clock() # doesn't really work as part of the generic query API, because the @@ -70,7 +70,8 @@ class E2eKeysHandler(object): remote_queries = {} for user_id, device_ids in device_keys_query.items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): local_query[user_id] = device_ids else: remote_queries[user_id] = device_ids @@ -170,7 +171,8 @@ class E2eKeysHandler(object): result_dict = {} for user_id, device_ids in query.items(): - if not self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if not self.is_mine(UserID.from_string(user_id)): logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") @@ -213,7 +215,8 @@ class E2eKeysHandler(object): remote_queries = {} for user_id, device_keys in query.get("one_time_keys", {}).items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): for device_id, algorithm in device_keys.items(): local_query.append((user_id, device_id, algorithm)) else: diff --git a/synapse/http/server.py b/synapse/http/server.py index 269b65ca41..165c684d0d 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -93,6 +93,8 @@ response_db_txn_count = metrics.register_counter( ), ) +# seconds spent waiting for db txns, excluding scheduling time, when processing +# this request response_db_txn_duration = metrics.register_counter( "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"], alternative_names=( @@ -100,6 +102,10 @@ response_db_txn_duration = metrics.register_counter( ), ) +# seconds spent waiting for a db connection, when processing this request +response_db_sched_duration = metrics.register_counter( + "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"] +) _next_request_id = 0 @@ -316,15 +322,6 @@ class JsonResource(HttpServer, resource.Resource): def _send_response(self, request, code, response_json_object, response_code_message=None): - # could alternatively use request.notifyFinish() and flip a flag when - # the Deferred fires, but since the flag is RIGHT THERE it seems like - # a waste. - if request._disconnected: - logger.warn( - "Not sending response to request %s, already disconnected.", - request) - return - outgoing_responses_counter.inc(request.method, str(code)) # TODO: Only enable CORS for the requests that need it. @@ -377,7 +374,10 @@ class RequestMetrics(object): context.db_txn_count, request.method, self.name, tag ) response_db_txn_duration.inc_by( - context.db_txn_duration, request.method, self.name, tag + context.db_txn_duration_ms / 1000., request.method, self.name, tag + ) + response_db_sched_duration.inc_by( + context.db_sched_duration_ms / 1000., request.method, self.name, tag ) @@ -400,6 +400,15 @@ class RootRedirect(resource.Resource): def respond_with_json(request, code, json_object, send_cors=False, response_code_message=None, pretty_print=False, version_string="", canonical_json=True): + # could alternatively use request.notifyFinish() and flip a flag when + # the Deferred fires, but since the flag is RIGHT THERE it seems like + # a waste. + if request._disconnected: + logger.warn( + "Not sending response to request %s, already disconnected.", + request) + return + if pretty_print: json_bytes = encode_pretty_printed_json(json_object) + "\n" else: diff --git a/synapse/http/site.py b/synapse/http/site.py index cd1492b1c3..e422c8dfae 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -66,14 +66,15 @@ class SynapseRequest(Request): context = LoggingContext.current_context() ru_utime, ru_stime = context.get_resource_usage() db_txn_count = context.db_txn_count - db_txn_duration = context.db_txn_duration + db_txn_duration_ms = context.db_txn_duration_ms + db_sched_duration_ms = context.db_sched_duration_ms except Exception: ru_utime, ru_stime = (0, 0) - db_txn_count, db_txn_duration = (0, 0) + db_txn_count, db_txn_duration_ms = (0, 0) self.site.access_logger.info( "%s - %s - {%s}" - " Processed request: %dms (%dms, %dms) (%dms/%d)" + " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)" " %sB %s \"%s %s %s\" \"%s\"", self.getClientIP(), self.site.site_tag, @@ -81,7 +82,8 @@ class SynapseRequest(Request): int(time.time() * 1000) - self.start_time, int(ru_utime * 1000), int(ru_stime * 1000), - int(db_txn_duration * 1000), + db_sched_duration_ms, + db_txn_duration_ms, int(db_txn_count), self.sentLength, self.code, diff --git a/synapse/state.py b/synapse/state.py index 9e624b4937..1f9abf9d3d 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -341,7 +341,7 @@ class StateHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events( + new_state = yield resolve_events_with_factory( state_groups_ids.values(), state_map_factory=lambda ev_ids: self.store.get_events( ev_ids, get_prev_content=False, check_redacted=False, @@ -404,7 +404,7 @@ class StateHandler(object): } with Measure(self.clock, "state._resolve_events"): - new_state = resolve_events(state_set_ids, state_map) + new_state = resolve_events_with_state_map(state_set_ids, state_map) new_state = { key: state_map[ev_id] for key, ev_id in new_state.items() @@ -420,19 +420,17 @@ def _ordered_events(events): return sorted(events, key=key_func) -def resolve_events(state_sets, state_map_factory): +def resolve_events_with_state_map(state_sets, state_map): """ Args: state_sets(list): List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - state_map_factory(dict|callable): If callable, then will be called - with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. Otherwise, should be - a dict from event_id to event of all events in state_sets. + state_map(dict): a dict from event_id to event, for all events in + state_sets. Returns - dict[(str, str), synapse.events.FrozenEvent] is a map from - (type, state_key) to event. + dict[(str, str), synapse.events.FrozenEvent]: + a map from (type, state_key) to event. """ if len(state_sets) == 1: return state_sets[0] @@ -441,13 +439,6 @@ def resolve_events(state_sets, state_map_factory): state_sets, ) - if callable(state_map_factory): - return _resolve_with_state_fac( - unconflicted_state, conflicted_state, state_map_factory - ) - - state_map = state_map_factory - auth_events = _create_auth_events_from_maps( unconflicted_state, conflicted_state, state_map ) @@ -491,8 +482,26 @@ def _seperate(state_sets): @defer.inlineCallbacks -def _resolve_with_state_fac(unconflicted_state, conflicted_state, - state_map_factory): +def resolve_events_with_factory(state_sets, state_map_factory): + """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + state_map_factory(func): will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. + + Returns + Deferred[dict[(str, str), synapse.events.FrozenEvent]]: + a map from (type, state_key) to event. + """ + if len(state_sets) == 1: + defer.returnValue(state_sets[0]) + + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + needed_events = set( event_id for event_ids in conflicted_state.itervalues() diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b971f0cb18..68125006eb 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -291,33 +291,33 @@ class SQLBaseStore(object): @defer.inlineCallbacks def runInteraction(self, desc, func, *args, **kwargs): - """Wraps the .runInteraction() method on the underlying db_pool.""" - current_context = LoggingContext.current_context() + """Starts a transaction on the database and runs a given function - start_time = time.time() * 1000 + Arguments: + desc (str): description of the transaction, for logging and metrics + func (func): callback function, which will be called with a + database transaction (twisted.enterprise.adbapi.Transaction) as + its first argument, followed by `args` and `kwargs`. + + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + current_context = LoggingContext.current_context() after_callbacks = [] final_callbacks = [] def inner_func(conn, *args, **kwargs): - with LoggingContext("runInteraction") as context: - sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) - - if self.database_engine.is_connection_closed(conn): - logger.debug("Reconnecting closed database connection") - conn.reconnect() - - current_context.copy_to(context) - return self._new_transaction( - conn, desc, after_callbacks, final_callbacks, current_context, - func, *args, **kwargs - ) + return self._new_transaction( + conn, desc, after_callbacks, final_callbacks, current_context, + func, *args, **kwargs + ) try: - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) + result = yield self.runWithConnection(inner_func, *args, **kwargs) for after_callback, after_args, after_kwargs in after_callbacks: after_callback(*after_args, **after_kwargs) @@ -329,14 +329,27 @@ class SQLBaseStore(object): @defer.inlineCallbacks def runWithConnection(self, func, *args, **kwargs): - """Wraps the .runInteraction() method on the underlying db_pool.""" + """Wraps the .runWithConnection() method on the underlying db_pool. + + Arguments: + func (func): callback function, which will be called with a + database connection (twisted.enterprise.adbapi.Connection) as + its first argument, followed by `args` and `kwargs`. + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ current_context = LoggingContext.current_context() start_time = time.time() * 1000 def inner_func(conn, *args, **kwargs): with LoggingContext("runWithConnection") as context: - sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + sched_duration_ms = time.time() * 1000 - start_time + sql_scheduling_timer.inc_by(sched_duration_ms) + current_context.add_database_scheduled(sched_duration_ms) if self.database_engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ad1d782705..c5292a5311 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -27,7 +27,7 @@ from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError -from synapse.state import resolve_events +from synapse.state import resolve_events_with_factory from synapse.util.caches.descriptors import cached from synapse.types import get_domain_from_id @@ -557,7 +557,7 @@ class EventsStore(SQLBaseStore): to_return.update(evs) defer.returnValue(to_return) - current_state = yield resolve_events( + current_state = yield resolve_events_with_factory( state_sets, state_map_factory=get_events, ) diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index c9bff408ef..f150ef0103 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -641,8 +641,13 @@ class UserDirectoryStore(SQLBaseStore): """ if self.hs.config.user_directory_search_all_users: - join_clause = "" - where_clause = "?<>''" # naughty hack to keep the same number of binds + # dummy to keep the number of binds & aliases the same + join_clause = """ + LEFT JOIN ( + SELECT NULL as user_id WHERE NULL = ? + ) AS s USING (user_id)" + """ + where_clause = "" else: join_clause = """ LEFT JOIN users_in_public_rooms AS p USING (user_id) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 48c9f6802d..94fa7cac98 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -52,13 +52,17 @@ except Exception: class LoggingContext(object): """Additional context for log formatting. Contexts are scoped within a "with" block. + Args: name (str): Name for the context for debugging. """ __slots__ = [ - "previous_context", "name", "usage_start", "usage_end", "main_thread", - "__dict__", "tag", "alive", + "previous_context", "name", "ru_stime", "ru_utime", + "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms", + "usage_start", "usage_end", + "main_thread", "alive", + "request", "tag", ] thread_local = threading.local() @@ -83,6 +87,9 @@ class LoggingContext(object): def add_database_transaction(self, duration_ms): pass + def add_database_scheduled(self, sched_ms): + pass + def __nonzero__(self): return False @@ -94,9 +101,17 @@ class LoggingContext(object): self.ru_stime = 0. self.ru_utime = 0. self.db_txn_count = 0 - self.db_txn_duration = 0. + + # ms spent waiting for db txns, excluding scheduling time + self.db_txn_duration_ms = 0 + + # ms spent waiting for db txns to be scheduled + self.db_sched_duration_ms = 0 + self.usage_start = None + self.usage_end = None self.main_thread = threading.current_thread() + self.request = None self.tag = "" self.alive = True @@ -105,7 +120,11 @@ class LoggingContext(object): @classmethod def current_context(cls): - """Get the current logging context from thread local storage""" + """Get the current logging context from thread local storage + + Returns: + LoggingContext: the current logging context + """ return getattr(cls.thread_local, "current_context", cls.sentinel) @classmethod @@ -155,11 +174,13 @@ class LoggingContext(object): self.alive = False def copy_to(self, record): - """Copy fields from this context to the record""" - for key, value in self.__dict__.items(): - setattr(record, key, value) + """Copy logging fields from this context to a log record or + another LoggingContext + """ - record.ru_utime, record.ru_stime = self.get_resource_usage() + # 'request' is the only field we currently use in the logger, so that's + # all we need to copy + record.request = self.request def start(self): if threading.current_thread() is not self.main_thread: @@ -194,7 +215,16 @@ class LoggingContext(object): def add_database_transaction(self, duration_ms): self.db_txn_count += 1 - self.db_txn_duration += duration_ms / 1000. + self.db_txn_duration_ms += duration_ms + + def add_database_scheduled(self, sched_ms): + """Record a use of the database pool + + Args: + sched_ms (int): number of milliseconds it took us to get a + connection + """ + self.db_sched_duration_ms += sched_ms class LoggingContextFilter(logging.Filter): diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 8d22ff3068..059bb7fedf 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -72,6 +72,7 @@ block_db_txn_count = metrics.register_counter( ), ) +# seconds spent waiting for db txns, excluding scheduling time, in this block block_db_txn_duration = metrics.register_counter( "block_db_txn_duration_seconds", labels=["block_name"], alternative_names=( @@ -79,6 +80,11 @@ block_db_txn_duration = metrics.register_counter( ), ) +# seconds spent waiting for a db connection, in this block +block_db_sched_duration = metrics.register_counter( + "block_db_sched_duration_seconds", labels=["block_name"], +) + def measure_func(name): def wrapper(func): @@ -95,7 +101,9 @@ def measure_func(name): class Measure(object): __slots__ = [ "clock", "name", "start_context", "start", "new_context", "ru_utime", - "ru_stime", "db_txn_count", "db_txn_duration", "created_context" + "ru_stime", + "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms", + "created_context", ] def __init__(self, clock, name): @@ -115,7 +123,8 @@ class Measure(object): self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() self.db_txn_count = self.start_context.db_txn_count - self.db_txn_duration = self.start_context.db_txn_duration + self.db_txn_duration_ms = self.start_context.db_txn_duration_ms + self.db_sched_duration_ms = self.start_context.db_sched_duration_ms def __exit__(self, exc_type, exc_val, exc_tb): if isinstance(exc_type, Exception) or not self.start_context: @@ -145,7 +154,12 @@ class Measure(object): context.db_txn_count - self.db_txn_count, self.name ) block_db_txn_duration.inc_by( - context.db_txn_duration - self.db_txn_duration, self.name + (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000., + self.name + ) + block_db_sched_duration.inc_by( + (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000., + self.name ) if self.created_context: diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 570312da84..c899fecf5d 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase): def check_context(self, _, expected): self.assertEquals( - getattr(LoggingContext.current_context(), "test_key", None), + getattr(LoggingContext.current_context(), "request", None), expected ) @@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase): lookup_2_deferred = defer.Deferred() with LoggingContext("one") as context_one: - context_one.test_key = "one" + context_one.request = "one" wait_1_deferred = kr.wait_for_previous_lookups( ["server1"], @@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase): wait_1_deferred.addBoth(self.check_context, "one") with LoggingContext("two") as context_two: - context_two.test_key = "two" + context_two.request = "two" # set off another wait. It should block because the first lookup # hasn't yet completed. @@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase): @defer.inlineCallbacks def get_perspectives(**kwargs): self.assertEquals( - LoggingContext.current_context().test_key, "11", + LoggingContext.current_context().request, "11", ) with logcontext.PreserveLoggingContext(): yield persp_deferred @@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase): self.http_client.post_json.side_effect = get_perspectives with LoggingContext("11") as context_11: - context_11.test_key = "11" + context_11.request = "11" # start off a first set of lookups res_deferreds = kr.verify_json_objects_for_server( @@ -173,7 +173,7 @@ class KeyringTestCase(unittest.TestCase): self.assertIs(LoggingContext.current_context(), context_11) context_12 = LoggingContext("12") - context_12.test_key = "12" + context_12.request = "12" with logcontext.PreserveLoggingContext(context_12): # a second request for a server with outstanding requests # should block rather than start a second call @@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase): sentinel_context = LoggingContext.current_context() with LoggingContext("one") as context_one: - context_one.test_key = "one" + context_one.request = "one" defer = kr.verify_json_for_server("server9", {}) try: diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index e2f7765f49..4850722bc5 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): self.assertEquals( - LoggingContext.current_context().test_key, value + LoggingContext.current_context().request, value ) def test_with_context(self): with LoggingContext() as context_one: - context_one.test_key = "test" + context_one.request = "test" self._check_test_key("test") @defer.inlineCallbacks @@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def competing_callback(): with LoggingContext() as competing_context: - competing_context.test_key = "competing" + competing_context.request = "competing" yield sleep(0) self._check_test_key("competing") reactor.callLater(0, competing_callback) with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" yield sleep(0) self._check_test_key("one") @@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def cb(): - context_one.test_key = "one" + context_one.request = "one" yield function() self._check_test_key("one") callback_completed[0] = True with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" # fire off function, but don't wait on it. logcontext.preserve_fn(cb)() @@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase): sentinel_context = LoggingContext.current_context() with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" d1 = logcontext.make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable @@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase): argument isn't actually a deferred""" with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" d1 = logcontext.make_deferred_yieldable("bum") self._check_test_key("one") |