From 39f4e29d0151b56a3c8528e3149cd5765b9f600d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 15 Jan 2018 17:00:12 +0000 Subject: Reorganise request and block metrics In order to circumvent the number of duplicate foo:count metrics increasing without bounds, it's time for a rearrangement. The following are all deprecated, and replaced with synapse_util_metrics_block_count: synapse_util_metrics_block_timer:count synapse_util_metrics_block_ru_utime:count synapse_util_metrics_block_ru_stime:count synapse_util_metrics_block_db_txn_count:count synapse_util_metrics_block_db_txn_duration:count The following are all deprecated, and replaced with synapse_http_server_response_count: synapse_http_server_requests synapse_http_server_response_time:count synapse_http_server_response_ru_utime:count synapse_http_server_response_ru_stime:count synapse_http_server_response_db_txn_count:count synapse_http_server_response_db_txn_duration:count The following are renamed (the old metrics are kept for now, but deprecated): synapse_util_metrics_block_timer:total -> synapse_util_metrics_block_time_seconds synapse_util_metrics_block_ru_utime:total -> synapse_util_metrics_block_ru_utime_seconds synapse_util_metrics_block_ru_stime:total -> synapse_util_metrics_block_ru_stime_seconds synapse_util_metrics_block_db_txn_count:total -> synapse_util_metrics_block_db_txn_count synapse_util_metrics_block_db_txn_duration:total -> synapse_util_metrics_block_db_txn_duration_seconds synapse_http_server_response_time:total -> synapse_http_server_response_time_seconds synapse_http_server_response_ru_utime:total -> synapse_http_server_response_ru_utime_seconds synapse_http_server_response_ru_stime:total -> synapse_http_server_response_ru_stime_seconds synapse_http_server_response_db_txn_count:total -> synapse_http_server_response_db_txn_count synapse_http_server_response_db_txn_duration:total synapse_http_server_response_db_txn_duration_seconds --- synapse/util/metrics.py | 53 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 11 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4ea930d3e8..8d22ff3068 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -27,25 +27,56 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) -block_timer = metrics.register_distribution( - "block_timer", - labels=["block_name"] +# total number of times we have hit this block +response_count = metrics.register_counter( + "block_count", + labels=["block_name"], + alternative_names=( + # the following are all deprecated aliases for the same metric + metrics.name_prefix + x for x in ( + "_block_timer:count", + "_block_ru_utime:count", + "_block_ru_stime:count", + "_block_db_txn_count:count", + "_block_db_txn_duration:count", + ) + ) +) + +block_timer = metrics.register_counter( + "block_time_seconds", + labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_timer:total", + ), ) -block_ru_utime = metrics.register_distribution( - "block_ru_utime", labels=["block_name"] +block_ru_utime = metrics.register_counter( + "block_ru_utime_seconds", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_ru_utime:total", + ), ) -block_ru_stime = metrics.register_distribution( - "block_ru_stime", labels=["block_name"] +block_ru_stime = metrics.register_counter( + "block_ru_stime_seconds", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_ru_stime:total", + ), ) -block_db_txn_count = metrics.register_distribution( - "block_db_txn_count", labels=["block_name"] +block_db_txn_count = metrics.register_counter( + "block_db_txn_count", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_db_txn_count:total", + ), ) -block_db_txn_duration = metrics.register_distribution( - "block_db_txn_duration", labels=["block_name"] +block_db_txn_duration = metrics.register_counter( + "block_db_txn_duration_seconds", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_db_txn_count:total", + ), ) -- cgit 1.5.1 From 44a498418c62a835aae9bff8550f844888b3ab84 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 11 Jan 2018 22:40:51 +0000 Subject: Optimise LoggingContext creation and copying It turns out that the only thing we use the __dict__ of LoggingContext for is `request`, and given we create lots of LoggingContexts and then copy them every time we do a db transaction or log line, using the __dict__ seems a bit redundant. Let's try to optimise things by making the request attribute explicit. --- synapse/util/logcontext.py | 25 ++++++++++++++++++------- tests/crypto/test_keyring.py | 14 +++++++------- tests/util/test_logcontext.py | 16 ++++++++-------- 3 files changed, 33 insertions(+), 22 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 48c9f6802d..ca71a1fc27 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -52,13 +52,16 @@ except Exception: class LoggingContext(object): """Additional context for log formatting. Contexts are scoped within a "with" block. + Args: name (str): Name for the context for debugging. """ __slots__ = [ - "previous_context", "name", "usage_start", "usage_end", "main_thread", - "__dict__", "tag", "alive", + "previous_context", "name", "ru_stime", "ru_utime", + "db_txn_count", "db_txn_duration", "usage_start", "usage_end", + "main_thread", "alive", + "request", "tag", ] thread_local = threading.local() @@ -96,7 +99,9 @@ class LoggingContext(object): self.db_txn_count = 0 self.db_txn_duration = 0. self.usage_start = None + self.usage_end = None self.main_thread = threading.current_thread() + self.request = None self.tag = "" self.alive = True @@ -105,7 +110,11 @@ class LoggingContext(object): @classmethod def current_context(cls): - """Get the current logging context from thread local storage""" + """Get the current logging context from thread local storage + + Returns: + LoggingContext: the current logging context + """ return getattr(cls.thread_local, "current_context", cls.sentinel) @classmethod @@ -155,11 +164,13 @@ class LoggingContext(object): self.alive = False def copy_to(self, record): - """Copy fields from this context to the record""" - for key, value in self.__dict__.items(): - setattr(record, key, value) + """Copy logging fields from this context to a log record or + another LoggingContext + """ - record.ru_utime, record.ru_stime = self.get_resource_usage() + # 'request' is the only field we currently use in the logger, so that's + # all we need to copy + record.request = self.request def start(self): if threading.current_thread() is not self.main_thread: diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 570312da84..c899fecf5d 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase): def check_context(self, _, expected): self.assertEquals( - getattr(LoggingContext.current_context(), "test_key", None), + getattr(LoggingContext.current_context(), "request", None), expected ) @@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase): lookup_2_deferred = defer.Deferred() with LoggingContext("one") as context_one: - context_one.test_key = "one" + context_one.request = "one" wait_1_deferred = kr.wait_for_previous_lookups( ["server1"], @@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase): wait_1_deferred.addBoth(self.check_context, "one") with LoggingContext("two") as context_two: - context_two.test_key = "two" + context_two.request = "two" # set off another wait. It should block because the first lookup # hasn't yet completed. @@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase): @defer.inlineCallbacks def get_perspectives(**kwargs): self.assertEquals( - LoggingContext.current_context().test_key, "11", + LoggingContext.current_context().request, "11", ) with logcontext.PreserveLoggingContext(): yield persp_deferred @@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase): self.http_client.post_json.side_effect = get_perspectives with LoggingContext("11") as context_11: - context_11.test_key = "11" + context_11.request = "11" # start off a first set of lookups res_deferreds = kr.verify_json_objects_for_server( @@ -173,7 +173,7 @@ class KeyringTestCase(unittest.TestCase): self.assertIs(LoggingContext.current_context(), context_11) context_12 = LoggingContext("12") - context_12.test_key = "12" + context_12.request = "12" with logcontext.PreserveLoggingContext(context_12): # a second request for a server with outstanding requests # should block rather than start a second call @@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase): sentinel_context = LoggingContext.current_context() with LoggingContext("one") as context_one: - context_one.test_key = "one" + context_one.request = "one" defer = kr.verify_json_for_server("server9", {}) try: diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index e2f7765f49..4850722bc5 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): self.assertEquals( - LoggingContext.current_context().test_key, value + LoggingContext.current_context().request, value ) def test_with_context(self): with LoggingContext() as context_one: - context_one.test_key = "test" + context_one.request = "test" self._check_test_key("test") @defer.inlineCallbacks @@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def competing_callback(): with LoggingContext() as competing_context: - competing_context.test_key = "competing" + competing_context.request = "competing" yield sleep(0) self._check_test_key("competing") reactor.callLater(0, competing_callback) with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" yield sleep(0) self._check_test_key("one") @@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def cb(): - context_one.test_key = "one" + context_one.request = "one" yield function() self._check_test_key("one") callback_completed[0] = True with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" # fire off function, but don't wait on it. logcontext.preserve_fn(cb)() @@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase): sentinel_context = LoggingContext.current_context() with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" d1 = logcontext.make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable @@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase): argument isn't actually a deferred""" with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" d1 = logcontext.make_deferred_yieldable("bum") self._check_test_key("one") -- cgit 1.5.1 From 6324b65f08b3f8dbfee6fef0079e2a87cb1c2c85 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 11 Jan 2018 18:17:54 +0000 Subject: Track db txn time in millisecs ... to reduce the amount of floating-point foo we do. --- synapse/http/server.py | 4 +++- synapse/http/site.py | 6 +++--- synapse/util/logcontext.py | 9 ++++++--- synapse/util/metrics.py | 8 +++++--- 4 files changed, 17 insertions(+), 10 deletions(-) (limited to 'synapse/util') diff --git a/synapse/http/server.py b/synapse/http/server.py index 269b65ca41..0f30e6fd56 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -93,6 +93,8 @@ response_db_txn_count = metrics.register_counter( ), ) +# seconds spent waiting for db txns, excluding scheduling time, when processing +# this request response_db_txn_duration = metrics.register_counter( "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"], alternative_names=( @@ -377,7 +379,7 @@ class RequestMetrics(object): context.db_txn_count, request.method, self.name, tag ) response_db_txn_duration.inc_by( - context.db_txn_duration, request.method, self.name, tag + context.db_txn_duration_ms / 1000., request.method, self.name, tag ) diff --git a/synapse/http/site.py b/synapse/http/site.py index cd1492b1c3..dc64f0f6f5 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -66,10 +66,10 @@ class SynapseRequest(Request): context = LoggingContext.current_context() ru_utime, ru_stime = context.get_resource_usage() db_txn_count = context.db_txn_count - db_txn_duration = context.db_txn_duration + db_txn_duration_ms = context.db_txn_duration_ms except Exception: ru_utime, ru_stime = (0, 0) - db_txn_count, db_txn_duration = (0, 0) + db_txn_count, db_txn_duration_ms = (0, 0) self.site.access_logger.info( "%s - %s - {%s}" @@ -81,7 +81,7 @@ class SynapseRequest(Request): int(time.time() * 1000) - self.start_time, int(ru_utime * 1000), int(ru_stime * 1000), - int(db_txn_duration * 1000), + db_txn_duration_ms, int(db_txn_count), self.sentLength, self.code, diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index ca71a1fc27..a78e53812f 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -59,7 +59,7 @@ class LoggingContext(object): __slots__ = [ "previous_context", "name", "ru_stime", "ru_utime", - "db_txn_count", "db_txn_duration", "usage_start", "usage_end", + "db_txn_count", "db_txn_duration_ms", "usage_start", "usage_end", "main_thread", "alive", "request", "tag", ] @@ -97,7 +97,10 @@ class LoggingContext(object): self.ru_stime = 0. self.ru_utime = 0. self.db_txn_count = 0 - self.db_txn_duration = 0. + + # ms spent waiting for db txns, excluding scheduling time + self.db_txn_duration_ms = 0 + self.usage_start = None self.usage_end = None self.main_thread = threading.current_thread() @@ -205,7 +208,7 @@ class LoggingContext(object): def add_database_transaction(self, duration_ms): self.db_txn_count += 1 - self.db_txn_duration += duration_ms / 1000. + self.db_txn_duration_ms += duration_ms class LoggingContextFilter(logging.Filter): diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 8d22ff3068..d25629cc50 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -72,6 +72,7 @@ block_db_txn_count = metrics.register_counter( ), ) +# seconds spent waiting for db txns, excluding scheduling time, in this block block_db_txn_duration = metrics.register_counter( "block_db_txn_duration_seconds", labels=["block_name"], alternative_names=( @@ -95,7 +96,7 @@ def measure_func(name): class Measure(object): __slots__ = [ "clock", "name", "start_context", "start", "new_context", "ru_utime", - "ru_stime", "db_txn_count", "db_txn_duration", "created_context" + "ru_stime", "db_txn_count", "db_txn_duration_ms", "created_context" ] def __init__(self, clock, name): @@ -115,7 +116,7 @@ class Measure(object): self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() self.db_txn_count = self.start_context.db_txn_count - self.db_txn_duration = self.start_context.db_txn_duration + self.db_txn_duration_ms = self.start_context.db_txn_duration_ms def __exit__(self, exc_type, exc_val, exc_tb): if isinstance(exc_type, Exception) or not self.start_context: @@ -145,7 +146,8 @@ class Measure(object): context.db_txn_count - self.db_txn_count, self.name ) block_db_txn_duration.inc_by( - context.db_txn_duration - self.db_txn_duration, self.name + (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000., + self.name ) if self.created_context: -- cgit 1.5.1 From 3d12d97415ac6d6a4ab8188af31c7df12c5d19f8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 12 Jan 2018 00:27:14 +0000 Subject: Track DB scheduling delay per-request For each request, track the amount of time spent waiting for a db connection. This entails adding it to the LoggingContext and we may as well add metrics for it while we are passing. --- synapse/http/server.py | 7 +++++++ synapse/http/site.py | 4 +++- synapse/storage/_base.py | 4 +++- synapse/util/logcontext.py | 18 +++++++++++++++++- synapse/util/metrics.py | 14 +++++++++++++- 5 files changed, 43 insertions(+), 4 deletions(-) (limited to 'synapse/util') diff --git a/synapse/http/server.py b/synapse/http/server.py index 0f30e6fd56..7b6418bc2c 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -102,6 +102,10 @@ response_db_txn_duration = metrics.register_counter( ), ) +# seconds spent waiting for a db connection, when processing this request +response_db_sched_duration = metrics.register_counter( + "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"] +) _next_request_id = 0 @@ -381,6 +385,9 @@ class RequestMetrics(object): response_db_txn_duration.inc_by( context.db_txn_duration_ms / 1000., request.method, self.name, tag ) + response_db_sched_duration.inc_by( + context.db_sched_duration_ms / 1000., request.method, self.name, tag + ) class RootRedirect(resource.Resource): diff --git a/synapse/http/site.py b/synapse/http/site.py index dc64f0f6f5..e422c8dfae 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -67,13 +67,14 @@ class SynapseRequest(Request): ru_utime, ru_stime = context.get_resource_usage() db_txn_count = context.db_txn_count db_txn_duration_ms = context.db_txn_duration_ms + db_sched_duration_ms = context.db_sched_duration_ms except Exception: ru_utime, ru_stime = (0, 0) db_txn_count, db_txn_duration_ms = (0, 0) self.site.access_logger.info( "%s - %s - {%s}" - " Processed request: %dms (%dms, %dms) (%dms/%d)" + " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)" " %sB %s \"%s %s %s\" \"%s\"", self.getClientIP(), self.site.site_tag, @@ -81,6 +82,7 @@ class SynapseRequest(Request): int(time.time() * 1000) - self.start_time, int(ru_utime * 1000), int(ru_stime * 1000), + db_sched_duration_ms, db_txn_duration_ms, int(db_txn_count), self.sentLength, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 986617674c..68125006eb 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -347,7 +347,9 @@ class SQLBaseStore(object): def inner_func(conn, *args, **kwargs): with LoggingContext("runWithConnection") as context: - sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + sched_duration_ms = time.time() * 1000 - start_time + sql_scheduling_timer.inc_by(sched_duration_ms) + current_context.add_database_scheduled(sched_duration_ms) if self.database_engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index a78e53812f..94fa7cac98 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -59,7 +59,8 @@ class LoggingContext(object): __slots__ = [ "previous_context", "name", "ru_stime", "ru_utime", - "db_txn_count", "db_txn_duration_ms", "usage_start", "usage_end", + "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms", + "usage_start", "usage_end", "main_thread", "alive", "request", "tag", ] @@ -86,6 +87,9 @@ class LoggingContext(object): def add_database_transaction(self, duration_ms): pass + def add_database_scheduled(self, sched_ms): + pass + def __nonzero__(self): return False @@ -101,6 +105,9 @@ class LoggingContext(object): # ms spent waiting for db txns, excluding scheduling time self.db_txn_duration_ms = 0 + # ms spent waiting for db txns to be scheduled + self.db_sched_duration_ms = 0 + self.usage_start = None self.usage_end = None self.main_thread = threading.current_thread() @@ -210,6 +217,15 @@ class LoggingContext(object): self.db_txn_count += 1 self.db_txn_duration_ms += duration_ms + def add_database_scheduled(self, sched_ms): + """Record a use of the database pool + + Args: + sched_ms (int): number of milliseconds it took us to get a + connection + """ + self.db_sched_duration_ms += sched_ms + class LoggingContextFilter(logging.Filter): """Logging filter that adds values from the current logging context to each diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index d25629cc50..059bb7fedf 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -80,6 +80,11 @@ block_db_txn_duration = metrics.register_counter( ), ) +# seconds spent waiting for a db connection, in this block +block_db_sched_duration = metrics.register_counter( + "block_db_sched_duration_seconds", labels=["block_name"], +) + def measure_func(name): def wrapper(func): @@ -96,7 +101,9 @@ def measure_func(name): class Measure(object): __slots__ = [ "clock", "name", "start_context", "start", "new_context", "ru_utime", - "ru_stime", "db_txn_count", "db_txn_duration_ms", "created_context" + "ru_stime", + "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms", + "created_context", ] def __init__(self, clock, name): @@ -117,6 +124,7 @@ class Measure(object): self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() self.db_txn_count = self.start_context.db_txn_count self.db_txn_duration_ms = self.start_context.db_txn_duration_ms + self.db_sched_duration_ms = self.start_context.db_sched_duration_ms def __exit__(self, exc_type, exc_val, exc_tb): if isinstance(exc_type, Exception) or not self.start_context: @@ -149,6 +157,10 @@ class Measure(object): (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000., self.name ) + block_db_sched_duration.inc_by( + (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000., + self.name + ) if self.created_context: self.start_context.__exit__(exc_type, exc_val, exc_tb) -- cgit 1.5.1 From bc67e7d260631d3fa7bc78653376e15dc0771364 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 17 Jan 2018 16:43:03 +0000 Subject: Add decent impl of a FileConsumer Twisted core doesn't have a general purpose one, so we need to write one ourselves. Features: - All writing happens in background thread - Supports both push and pull producers - Push producers get paused if the consumer falls behind --- synapse/util/file_consumer.py | 158 +++++++++++++++++++++++++++++++++++++++ tests/util/test_file_consumer.py | 138 ++++++++++++++++++++++++++++++++++ 2 files changed, 296 insertions(+) create mode 100644 synapse/util/file_consumer.py create mode 100644 tests/util/test_file_consumer.py (limited to 'synapse/util') diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py new file mode 100644 index 0000000000..de478fcb3e --- /dev/null +++ b/synapse/util/file_consumer.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vecotr Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer, threads, reactor + +from synapse.util.logcontext import make_deferred_yieldable + +import Queue + + +class BackgroundFileConsumer(object): + """A consumer that writes to a file like object. Supports both push + and pull producers + + Args: + file_obj (file): The file like object to write to. Closed when + finished. + """ + + # For PushProducers pause if we have this many unwritten slices + _PAUSE_ON_QUEUE_SIZE = 5 + # And resume once the size of the queue is less than this + _RESUME_ON_QUEUE_SIZE = 2 + + def __init__(self, file_obj): + self.file_obj = file_obj + + # Producer we're registered with + self.producer = None + + # True if PushProducer, false if PullProducer + self.streaming = False + + # Queue of slices of bytes to be written. When producer calls + # unregister a final None is sent. + self.bytes_queue = Queue.Queue() + + # Deferred that is resolved when finished writing + self.finished_deferred = None + + # If the _writer thread throws an exception it gets stored here. + self._write_exception = None + + # A deferred that gets resolved when the bytes_queue gets empty. + # Mainly used for tests. + self._notify_empty_deferred = None + + def registerProducer(self, producer, streaming): + """Part of IProducer interface + + Args: + producer (IProducer) + streaming (bool): True if push based producer, False if pull + based. + """ + self.producer = producer + self.streaming = streaming + self.finished_deferred = threads.deferToThread(self._writer) + if not streaming: + self.producer.resumeProducing() + + self.paused_producer = False + + def unregisterProducer(self): + """Part of IProducer interface + """ + self.producer = None + if not self.finished_deferred.called: + self.bytes_queue.put_nowait(None) + + def write(self, bytes): + """Part of IProducer interface + """ + if self._write_exception: + raise self._write_exception + + if self.finished_deferred.called: + raise Exception("consumer has closed") + + self.bytes_queue.put_nowait(bytes) + + # If this is a pushed based consumer and the queue is getting behind + # then we pause the producer. + if self.streaming and self.bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: + self.paused_producer = True + self.producer.pauseProducing() + + def _writer(self): + """This is run in a background thread to write to the file. + """ + try: + while self.producer or not self.bytes_queue.empty(): + # If we've paused the producer check if we should resume the + # producer. + if self.producer and self.paused_producer: + if self.bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: + reactor.callFromThread(self._resume_paused_producer) + + if self._notify_empty and self.bytes_queue.empty(): + reactor.callFromThread(self._notify_empty) + + bytes = self.bytes_queue.get() + + # If we get a None (or empty list) then that's a signal used + # to indicate we should check if we should stop. + if bytes: + self.file_obj.write(bytes) + + # If its a pull producer then we need to explicitly ask for + # more stuff. + if not self.streaming and self.producer: + reactor.callFromThread(self.producer.resumeProducing) + except Exception as e: + self._write_exception = e + raise + finally: + self.file_obj.close() + + def wait(self): + """Returns a deferred that resolves when finished writing to file + """ + return make_deferred_yieldable(self.finished_deferred) + + def _resume_paused_producer(self): + """Gets called if we should resume producing after being paused + """ + if self.paused_producer and self.producer: + self.paused_producer = False + self.producer.resumeProducing() + + def _notify_empty(self): + """Called when the _writer thread thinks the queue may be empty and + we should notify anything waiting on `wait_for_writes` + """ + if self._notify_empty_deferred and self.bytes_queue.empty(): + d = self._notify_empty_deferred + self._notify_empty_deferred = None + d.callback(None) + + def wait_for_writes(self): + """Wait for the write queue to be empty and for writes to have + finished. This is mainly useful for tests. + """ + if not self._notify_empty_deferred: + self._notify_empty_deferred = defer.Deferred() + return self._notify_empty_deferred diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py new file mode 100644 index 0000000000..8acb68f0c3 --- /dev/null +++ b/tests/util/test_file_consumer.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer +from mock import NonCallableMock + +from synapse.util.file_consumer import BackgroundFileConsumer + +from tests import unittest +from StringIO import StringIO + +import threading + + +class FileConsumerTests(unittest.TestCase): + + @defer.inlineCallbacks + def test_pull_consumer(self): + string_file = StringIO() + consumer = BackgroundFileConsumer(string_file) + + try: + producer = DummyPullProducer() + + yield producer.register_with_consumer(consumer) + + yield producer.write_and_wait("Foo") + + self.assertEqual(string_file.getvalue(), "Foo") + + yield producer.write_and_wait("Bar") + + self.assertEqual(string_file.getvalue(), "FooBar") + finally: + consumer.unregisterProducer() + + yield consumer.wait() + + self.assertTrue(string_file.closed) + + @defer.inlineCallbacks + def test_push_consumer(self): + string_file = StringIO() + consumer = BackgroundFileConsumer(string_file) + + try: + producer = NonCallableMock(spec_set=[]) + + consumer.registerProducer(producer, True) + + consumer.write("Foo") + yield consumer.wait_for_writes() + + self.assertEqual(string_file.getvalue(), "Foo") + + consumer.write("Bar") + yield consumer.wait_for_writes() + + self.assertEqual(string_file.getvalue(), "FooBar") + finally: + consumer.unregisterProducer() + + yield consumer.wait() + + self.assertTrue(string_file.closed) + + @defer.inlineCallbacks + def test_push_producer_feedback(self): + string_file = BlockingStringWrite() + consumer = BackgroundFileConsumer(string_file) + + try: + producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) + + consumer.registerProducer(producer, True) + + with string_file.write_lock: + for _ in range(consumer._PAUSE_ON_QUEUE_SIZE): + consumer.write("Foo") + + producer.pauseProducing.assert_called_once() + + yield consumer.wait_for_writes() + producer.resumeProducing.assert_called_once() + finally: + consumer.unregisterProducer() + + yield consumer.wait() + + self.assertTrue(string_file.closed) + + +class DummyPullProducer(object): + def __init__(self): + self.consumer = None + self.deferred = defer.Deferred() + + def resumeProducing(self): + d = self.deferred + self.deferred = defer.Deferred() + d.callback(None) + + def write_and_wait(self, bytes): + d = self.deferred + self.consumer.write(bytes) + return d + + def register_with_consumer(self, consumer): + d = self.deferred + self.consumer = consumer + self.consumer.registerProducer(self, False) + return d + + +class BlockingStringWrite(object): + def __init__(self): + self.buffer = "" + self.closed = False + self.write_lock = threading.Lock() + + def write(self, bytes): + self.buffer += bytes + + def close(self): + self.closed = True -- cgit 1.5.1 From a177325b49be4793c8ed21147f8d301a0649a2b6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Jan 2018 11:02:43 +0000 Subject: Fix comments --- synapse/util/file_consumer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index de478fcb3e..5284c7967e 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vecotr Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -58,7 +58,7 @@ class BackgroundFileConsumer(object): self._notify_empty_deferred = None def registerProducer(self, producer, streaming): - """Part of IProducer interface + """Part of IConsumer interface Args: producer (IProducer) @@ -91,7 +91,7 @@ class BackgroundFileConsumer(object): self.bytes_queue.put_nowait(bytes) - # If this is a pushed based consumer and the queue is getting behind + # If this is a PushProducer and the queue is getting behind # then we pause the producer. if self.streaming and self.bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: self.paused_producer = True -- cgit 1.5.1 From 28b338ed9bafc2017a635848e14a2a25b78d0016 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Jan 2018 11:04:41 +0000 Subject: Move definition of paused_producer to __init__ --- synapse/util/file_consumer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 5284c7967e..54c9da9573 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -43,6 +43,10 @@ class BackgroundFileConsumer(object): # True if PushProducer, false if PullProducer self.streaming = False + # For PushProducers, indicates whether we've paused the producer and + # need to call resumeProducing before we get more data. + self.paused_producer = False + # Queue of slices of bytes to be written. When producer calls # unregister a final None is sent. self.bytes_queue = Queue.Queue() @@ -71,8 +75,6 @@ class BackgroundFileConsumer(object): if not streaming: self.producer.resumeProducing() - self.paused_producer = False - def unregisterProducer(self): """Part of IProducer interface """ -- cgit 1.5.1 From 17b54389feb3855a33406149a8a59f0327bb3ad1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Jan 2018 11:05:34 +0000 Subject: Fix _notify_empty typo --- synapse/util/file_consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/util') diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 54c9da9573..479e480614 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -110,7 +110,7 @@ class BackgroundFileConsumer(object): if self.bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: reactor.callFromThread(self._resume_paused_producer) - if self._notify_empty and self.bytes_queue.empty(): + if self._notify_empty_deferred and self.bytes_queue.empty(): reactor.callFromThread(self._notify_empty) bytes = self.bytes_queue.get() -- cgit 1.5.1 From dc519602ac0f35d39a70c91f0e6057e865a61dfc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Jan 2018 11:07:17 +0000 Subject: Ensure we registerProducer isn't called twice --- synapse/util/file_consumer.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'synapse/util') diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 479e480614..d7bbb0aeb8 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -69,6 +69,9 @@ class BackgroundFileConsumer(object): streaming (bool): True if push based producer, False if pull based. """ + if self.producer: + raise Exception("registerProducer called twice") + self.producer = producer self.streaming = streaming self.finished_deferred = threads.deferToThread(self._writer) -- cgit 1.5.1 From 2f18a2647b6b9cc07c3cc5f2bec3e1bab67d0eea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Jan 2018 11:10:12 +0000 Subject: Make all fields private --- synapse/util/file_consumer.py | 62 +++++++++++++++++++++---------------------- 1 file changed, 31 insertions(+), 31 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index d7bbb0aeb8..d19d48665c 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -35,24 +35,24 @@ class BackgroundFileConsumer(object): _RESUME_ON_QUEUE_SIZE = 2 def __init__(self, file_obj): - self.file_obj = file_obj + self._file_obj = file_obj # Producer we're registered with - self.producer = None + self._producer = None # True if PushProducer, false if PullProducer self.streaming = False # For PushProducers, indicates whether we've paused the producer and # need to call resumeProducing before we get more data. - self.paused_producer = False + self._paused_producer = False # Queue of slices of bytes to be written. When producer calls # unregister a final None is sent. - self.bytes_queue = Queue.Queue() + self._bytes_queue = Queue.Queue() # Deferred that is resolved when finished writing - self.finished_deferred = None + self._finished_deferred = None # If the _writer thread throws an exception it gets stored here. self._write_exception = None @@ -69,21 +69,21 @@ class BackgroundFileConsumer(object): streaming (bool): True if push based producer, False if pull based. """ - if self.producer: + if self._producer: raise Exception("registerProducer called twice") - self.producer = producer + self._producer = producer self.streaming = streaming - self.finished_deferred = threads.deferToThread(self._writer) + self._finished_deferred = threads.deferToThread(self._writer) if not streaming: - self.producer.resumeProducing() + self._producer.resumeProducing() def unregisterProducer(self): """Part of IProducer interface """ - self.producer = None - if not self.finished_deferred.called: - self.bytes_queue.put_nowait(None) + self._producer = None + if not self._finished_deferred.called: + self._bytes_queue.put_nowait(None) def write(self, bytes): """Part of IProducer interface @@ -91,65 +91,65 @@ class BackgroundFileConsumer(object): if self._write_exception: raise self._write_exception - if self.finished_deferred.called: + if self._finished_deferred.called: raise Exception("consumer has closed") - self.bytes_queue.put_nowait(bytes) + self._bytes_queue.put_nowait(bytes) # If this is a PushProducer and the queue is getting behind # then we pause the producer. - if self.streaming and self.bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: - self.paused_producer = True - self.producer.pauseProducing() + if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: + self._paused_producer = True + self._producer.pauseProducing() def _writer(self): """This is run in a background thread to write to the file. """ try: - while self.producer or not self.bytes_queue.empty(): + while self._producer or not self._bytes_queue.empty(): # If we've paused the producer check if we should resume the # producer. - if self.producer and self.paused_producer: - if self.bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: + if self._producer and self._paused_producer: + if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: reactor.callFromThread(self._resume_paused_producer) - if self._notify_empty_deferred and self.bytes_queue.empty(): + if self._notify_empty_deferred and self._bytes_queue.empty(): reactor.callFromThread(self._notify_empty) - bytes = self.bytes_queue.get() + bytes = self._bytes_queue.get() # If we get a None (or empty list) then that's a signal used # to indicate we should check if we should stop. if bytes: - self.file_obj.write(bytes) + self._file_obj.write(bytes) # If its a pull producer then we need to explicitly ask for # more stuff. - if not self.streaming and self.producer: - reactor.callFromThread(self.producer.resumeProducing) + if not self.streaming and self._producer: + reactor.callFromThread(self._producer.resumeProducing) except Exception as e: self._write_exception = e raise finally: - self.file_obj.close() + self._file_obj.close() def wait(self): """Returns a deferred that resolves when finished writing to file """ - return make_deferred_yieldable(self.finished_deferred) + return make_deferred_yieldable(self._finished_deferred) def _resume_paused_producer(self): """Gets called if we should resume producing after being paused """ - if self.paused_producer and self.producer: - self.paused_producer = False - self.producer.resumeProducing() + if self._paused_producer and self._producer: + self._paused_producer = False + self._producer.resumeProducing() def _notify_empty(self): """Called when the _writer thread thinks the queue may be empty and we should notify anything waiting on `wait_for_writes` """ - if self._notify_empty_deferred and self.bytes_queue.empty(): + if self._notify_empty_deferred and self._bytes_queue.empty(): d = self._notify_empty_deferred self._notify_empty_deferred = None d.callback(None) -- cgit 1.5.1 From 1432f7ccd5a01e43d0c5417f3d2f4a6a0fbf5bfb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Jan 2018 11:53:21 +0000 Subject: Move test stuff to tests --- synapse/util/file_consumer.py | 26 +------------------ tests/util/test_file_consumer.py | 54 ++++++++++++++++++++++++++++++++++------ 2 files changed, 47 insertions(+), 33 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index d19d48665c..3241035247 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, threads, reactor +from twisted.internet import threads, reactor from synapse.util.logcontext import make_deferred_yieldable @@ -57,10 +57,6 @@ class BackgroundFileConsumer(object): # If the _writer thread throws an exception it gets stored here. self._write_exception = None - # A deferred that gets resolved when the bytes_queue gets empty. - # Mainly used for tests. - self._notify_empty_deferred = None - def registerProducer(self, producer, streaming): """Part of IConsumer interface @@ -113,9 +109,6 @@ class BackgroundFileConsumer(object): if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: reactor.callFromThread(self._resume_paused_producer) - if self._notify_empty_deferred and self._bytes_queue.empty(): - reactor.callFromThread(self._notify_empty) - bytes = self._bytes_queue.get() # If we get a None (or empty list) then that's a signal used @@ -144,20 +137,3 @@ class BackgroundFileConsumer(object): if self._paused_producer and self._producer: self._paused_producer = False self._producer.resumeProducing() - - def _notify_empty(self): - """Called when the _writer thread thinks the queue may be empty and - we should notify anything waiting on `wait_for_writes` - """ - if self._notify_empty_deferred and self._bytes_queue.empty(): - d = self._notify_empty_deferred - self._notify_empty_deferred = None - d.callback(None) - - def wait_for_writes(self): - """Wait for the write queue to be empty and for writes to have - finished. This is mainly useful for tests. - """ - if not self._notify_empty_deferred: - self._notify_empty_deferred = defer.Deferred() - return self._notify_empty_deferred diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index 8acb68f0c3..76e2234255 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -14,7 +14,7 @@ # limitations under the License. -from twisted.internet import defer +from twisted.internet import defer, reactor from mock import NonCallableMock from synapse.util.file_consumer import BackgroundFileConsumer @@ -53,7 +53,7 @@ class FileConsumerTests(unittest.TestCase): @defer.inlineCallbacks def test_push_consumer(self): - string_file = StringIO() + string_file = BlockingStringWrite() consumer = BackgroundFileConsumer(string_file) try: @@ -62,14 +62,14 @@ class FileConsumerTests(unittest.TestCase): consumer.registerProducer(producer, True) consumer.write("Foo") - yield consumer.wait_for_writes() + yield string_file.wait_for_n_writes(1) - self.assertEqual(string_file.getvalue(), "Foo") + self.assertEqual(string_file.buffer, "Foo") consumer.write("Bar") - yield consumer.wait_for_writes() + yield string_file.wait_for_n_writes(2) - self.assertEqual(string_file.getvalue(), "FooBar") + self.assertEqual(string_file.buffer, "FooBar") finally: consumer.unregisterProducer() @@ -85,15 +85,22 @@ class FileConsumerTests(unittest.TestCase): try: producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) + resume_deferred = defer.Deferred() + producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None) + consumer.registerProducer(producer, True) + number_writes = 0 with string_file.write_lock: for _ in range(consumer._PAUSE_ON_QUEUE_SIZE): consumer.write("Foo") + number_writes += 1 producer.pauseProducing.assert_called_once() - yield consumer.wait_for_writes() + yield string_file.wait_for_n_writes(number_writes) + + yield resume_deferred producer.resumeProducing.assert_called_once() finally: consumer.unregisterProducer() @@ -131,8 +138,39 @@ class BlockingStringWrite(object): self.closed = False self.write_lock = threading.Lock() + self._notify_write_deferred = None + self._number_of_writes = 0 + def write(self, bytes): - self.buffer += bytes + with self.write_lock: + self.buffer += bytes + self._number_of_writes += 1 + + reactor.callFromThread(self._notify_write) def close(self): self.closed = True + + def _notify_write(self): + "Called by write to indicate a write happened" + with self.write_lock: + if not self._notify_write_deferred: + return + d = self._notify_write_deferred + self._notify_write_deferred = None + d.callback(None) + + @defer.inlineCallbacks + def wait_for_n_writes(self, n): + "Wait for n writes to have happened" + while True: + with self.write_lock: + if n <= self._number_of_writes: + return + + if not self._notify_write_deferred: + self._notify_write_deferred = defer.Deferred() + + d = self._notify_write_deferred + + yield d -- cgit 1.5.1 From be0dfcd4a29859f4c707c2b3cf1da38c5115d251 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Jan 2018 11:57:23 +0000 Subject: Do logcontexts correctly --- synapse/util/file_consumer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 3241035247..90a2608d6f 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -15,7 +15,7 @@ from twisted.internet import threads, reactor -from synapse.util.logcontext import make_deferred_yieldable +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn import Queue @@ -70,7 +70,7 @@ class BackgroundFileConsumer(object): self._producer = producer self.streaming = streaming - self._finished_deferred = threads.deferToThread(self._writer) + self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer) if not streaming: self._producer.resumeProducing() -- cgit 1.5.1 From d57765fc8a1b54dae001bbb97b2b529991292fbc Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 18 Jan 2018 12:23:04 +0000 Subject: Fix bugs in block metrics ... which I introduced in #2785 --- synapse/util/metrics.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 059bb7fedf..e4b5687a4b 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) # total number of times we have hit this block -response_count = metrics.register_counter( +block_counter = metrics.register_counter( "block_count", labels=["block_name"], alternative_names=( @@ -76,7 +76,7 @@ block_db_txn_count = metrics.register_counter( block_db_txn_duration = metrics.register_counter( "block_db_txn_duration_seconds", labels=["block_name"], alternative_names=( - metrics.name_prefix + "_block_db_txn_count:total", + metrics.name_prefix + "_block_db_txn_duration:total", ), ) @@ -131,6 +131,8 @@ class Measure(object): return duration = self.clock.time_msec() - self.start + + block_counter.inc(self.name) block_timer.inc_by(duration, self.name) context = LoggingContext.current_context() -- cgit 1.5.1 From 447f4f0d5f136dcadd5fdc286ded2d6e24a3f686 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Fri, 19 Jan 2018 15:33:55 +0000 Subject: rewrite based on PR feedback: * [ ] split config options into allowed_local_3pids and registrations_require_3pid * [ ] simplify and comment logic for picking registration flows * [ ] fix docstring and move check_3pid_allowed into a new util module * [ ] use check_3pid_allowed everywhere @erikjohnston PTAL --- synapse/config/registration.py | 12 +++-- synapse/handlers/register.py | 15 +++---- synapse/rest/client/v1/register.py | 20 +++------ synapse/rest/client/v2_alpha/_base.py | 21 --------- synapse/rest/client/v2_alpha/account.py | 3 +- synapse/rest/client/v2_alpha/register.py | 75 +++++++++++++++----------------- synapse/util/threepids.py | 45 +++++++++++++++++++ 7 files changed, 102 insertions(+), 89 deletions(-) create mode 100644 synapse/util/threepids.py (limited to 'synapse/util') diff --git a/synapse/config/registration.py b/synapse/config/registration.py index e5e4f77872..336959094b 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -32,6 +32,7 @@ class RegistrationConfig(Config): ) self.registrations_require_3pid = config.get("registrations_require_3pid", []) + self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) @@ -53,11 +54,16 @@ class RegistrationConfig(Config): # Enable registration for new users. enable_registration: False - # Mandate that registrations require a 3PID which matches one or more - # of these 3PIDs. N.B. regexp escape backslashes are doubled (once for - # YAML and once for the regexp itself) + # The user must provide all of the below types of 3PID when registering. # # registrations_require_3pid: + # - email + # - msisdn + + # Mandate that users are only allowed to associate certain formats of + # 3PIDs with accounts on this server. + # + # allowed_local_3pids: # - medium: email # pattern: ".*@matrix\\.org" # - medium: email diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 157ebaf251..9021d4d57f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,7 +15,6 @@ """Contains functions for registering clients.""" import logging -import re from twisted.internet import defer @@ -26,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient from synapse import types from synapse.types import UserID from synapse.util.async import run_on_reactor +from synapse.util.threepids import check_3pid_allowed from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -308,15 +308,10 @@ class RegistrationHandler(BaseHandler): logger.info("got threepid with medium '%s' and address '%s'", threepid['medium'], threepid['address']) - for constraint in self.hs.config.registrations_require_3pid: - if ( - constraint['medium'] == 'email' and - threepid['medium'] == 'email' and - re.match(constraint['pattern'], threepid['address']) - ): - raise RegistrationError( - 403, "Third party identifier is not allowed" - ) + if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']): + raise RegistrationError( + 403, "Third party identifier is not allowed" + ) @defer.inlineCallbacks def bind_emails(self, user_id, threepidCreds): diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index f793542ad6..5c5fa8f7ab 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -71,22 +71,13 @@ class RegisterRestServlet(ClientV1RestServlet): def on_GET(self, request): - require_email = False - require_msisdn = False - for constraint in self.hs.config.registrations_require_3pid: - if constraint['medium'] == 'email': - require_email = True - elif constraint['medium'] == 'msisdn': - require_msisdn = True - else: - logger.warn( - "Unrecognised 3PID medium %s in registrations_require_3pid" % - constraint['medium'] - ) + require_email = 'email' in self.hs.config.registrations_require_3pid + require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid flows = [] if self.hs.config.enable_registration_captcha: - if require_email or not require_msisdn: + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: flows.extend([ { "type": LoginType.RECAPTCHA, @@ -97,6 +88,7 @@ class RegisterRestServlet(ClientV1RestServlet): ] }, ]) + # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: flows.extend([ { @@ -105,6 +97,7 @@ class RegisterRestServlet(ClientV1RestServlet): } ]) else: + # only support the email-only flow if we don't require MSISDN 3PIDs if require_email or not require_msisdn: flows.extend([ { @@ -114,6 +107,7 @@ class RegisterRestServlet(ClientV1RestServlet): ] } ]) + # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: flows.extend([ { diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index b286ff0d95..77434937ff 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -60,27 +60,6 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit): filter_timeline_limit) -def check_3pid_allowed(hs, medium, address): - # check whether the HS has whitelisted the given 3PID - - allow = False - if hs.config.registrations_require_3pid: - for constraint in hs.config.registrations_require_3pid: - logger.debug("Checking 3PID %s (%s) against %s (%s)" % ( - address, medium, constraint['pattern'], constraint['medium'] - )) - if ( - medium == constraint['medium'] and - re.match(constraint['pattern'], address) - ): - allow = True - break - else: - allow = True - - return allow - - def interactive_auth_handler(orig): """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 2977ad439f..514bb37da1 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -26,7 +26,8 @@ from synapse.http.servlet import ( ) from synapse.util.async import run_on_reactor from synapse.util.msisdn import phone_number_to_msisdn -from ._base import client_v2_patterns, interactive_auth_handler, check_3pid_allowed +from synapse.util.threepids import check_3pid_allowed +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 898d8b133a..c3479e29de 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -26,11 +26,11 @@ from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string ) from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.threepids import check_3pid_allowed -from ._base import client_v2_patterns, interactive_auth_handler, check_3pid_allowed +from ._base import client_v2_patterns, interactive_auth_handler import logging -import re import hmac from hashlib import sha1 from synapse.util.async import run_on_reactor @@ -316,41 +316,41 @@ class RegisterRestServlet(RestServlet): if 'x_show_msisdn' in body and body['x_show_msisdn']: show_msisdn = True - require_email = False - require_msisdn = False - for constraint in self.hs.config.registrations_require_3pid: - if constraint['medium'] == 'email': - require_email = True - elif constraint['medium'] == 'msisdn': - require_msisdn = True - else: - logger.warn( - "Unrecognised 3PID medium %s in registrations_require_3pid" % - constraint['medium'] - ) + # FIXME: need a better error than "no auth flow found" for scenarios + # where we required 3PID for registration but the user didn't give one + require_email = 'email' in self.hs.config.registrations_require_3pid + require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid flows = [] if self.hs.config.enable_registration_captcha: + # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: flows.extend([[LoginType.RECAPTCHA]]) - if require_email or not require_msisdn: + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]]) if show_msisdn: - if not require_email or require_msisdn: + # only support the MSISDN-only flow if we don't require email 3PIDs + if not require_email: flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]]) + # always let users provide both MSISDN & email flows.extend([ [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], ]) else: + # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: flows.extend([[LoginType.DUMMY]]) - if require_email or not require_msisdn: + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: flows.extend([[LoginType.EMAIL_IDENTITY]]) if show_msisdn: + # only support the MSISDN-only flow if we don't require email 3PIDs if not require_email or require_msisdn: flows.extend([[LoginType.MSISDN]]) + # always let users provide both MSISDN & email flows.extend([ [LoginType.MSISDN, LoginType.EMAIL_IDENTITY] ]) @@ -359,30 +359,23 @@ class RegisterRestServlet(RestServlet): flows, body, self.hs.get_ip_from_request(request) ) - # doublecheck that we're not trying to register an denied 3pid. - # the user-facing checks should already have happened when we requested - # a 3PID token to validate them in /register/email/requestToken etc - - for constraint in self.hs.config.registrations_require_3pid: - if ( - constraint['medium'] == 'email' and - auth_result and LoginType.EMAIL_IDENTITY in auth_result and - re.match( - constraint['pattern'], - auth_result[LoginType.EMAIL_IDENTITY].threepid.address - ) - ): - raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED - ) - elif ( - constraint['medium'] == 'msisdn' and - auth_result and LoginType.MSISDN in auth_result and - re.match( - constraint['pattern'], - auth_result[LoginType.MSISDN].threepid.address - ) - ): + # Check that we're not trying to register a denied 3pid. + # + # the user-facing checks will probably already have happened in + # /register/email/requestToken when we requested a 3pid, but that's not + # guaranteed. + + if ( + auth_result and + ( + LoginType.EMAIL_IDENTITY in auth_result or + LoginType.EMAIL_MSISDN in auth_result + ) + ): + medium = auth_result[LoginType.EMAIL_IDENTITY].threepid['medium'] + address = auth_result[LoginType.EMAIL_IDENTITY].threepid['address'] + + if not check_3pid_allowed(self.hs, medium, address): raise SynapseError( 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED ) diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py new file mode 100644 index 0000000000..e921b97796 --- /dev/null +++ b/synapse/util/threepids.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +logger = logging.getLogger(__name__) + + +def check_3pid_allowed(hs, medium, address): + """Checks whether a given format of 3PID is allowed to be used on this HS + + Args: + hs (synapse.server.HomeServer): server + medium (str): 3pid medium - e.g. email, msisdn + address (str): address within that medium (e.g. "wotan@matrix.org") + msisdns need to first have been canonicalised + """ + + if hs.config.allowed_local_3pids: + for constraint in hs.config.allowed_local_3pids: + logger.debug("Checking 3PID %s (%s) against %s (%s)" % ( + address, medium, constraint['pattern'], constraint['medium'] + )) + if ( + medium == constraint['medium'] and + re.match(constraint['pattern'], address) + ): + return True + else: + return True + + return False -- cgit 1.5.1 From 8fe253f19b1c61c38111948cce00a7d260d2925a Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Fri, 19 Jan 2018 18:23:45 +0000 Subject: fix PR nitpicking --- synapse/util/threepids.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index e921b97796..75efa0117b 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -27,13 +27,16 @@ def check_3pid_allowed(hs, medium, address): medium (str): 3pid medium - e.g. email, msisdn address (str): address within that medium (e.g. "wotan@matrix.org") msisdns need to first have been canonicalised + Returns: + bool: whether the 3PID medium/address is allowed to be added to this HS """ if hs.config.allowed_local_3pids: for constraint in hs.config.allowed_local_3pids: - logger.debug("Checking 3PID %s (%s) against %s (%s)" % ( - address, medium, constraint['pattern'], constraint['medium'] - )) + logger.debug( + "Checking 3PID %s (%s) against %s (%s)", + address, medium, constraint['pattern'], constraint['medium'], + ) if ( medium == constraint['medium'] and re.match(constraint['pattern'], address) -- cgit 1.5.1 From ab9f844aaf3662a64dbc4c56077e9fa37bc7d5d0 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Mon, 22 Jan 2018 19:11:18 +0100 Subject: Add federation_domain_whitelist option (#2820) Add federation_domain_whitelist gives a way to restrict which domains your HS is allowed to federate with. useful mainly for gracefully preventing a private but internet-connected HS from trying to federate to the wider public Matrix network --- synapse/api/errors.py | 26 ++++++++++++++++++++++++++ synapse/config/server.py | 22 ++++++++++++++++++++++ synapse/federation/federation_client.py | 5 ++++- synapse/federation/transaction_queue.py | 4 +++- synapse/federation/transport/client.py | 3 +++ synapse/federation/transport/server.py | 9 ++++++++- synapse/handlers/device.py | 4 ++++ synapse/handlers/e2e_keys.py | 8 +++++++- synapse/handlers/federation.py | 4 ++++ synapse/http/matrixfederationclient.py | 28 +++++++++++++++++++++++++++- synapse/rest/key/v2/remote_key_resource.py | 8 ++++++++ synapse/rest/media/v1/media_repository.py | 19 +++++++++++++++++-- synapse/util/retryutils.py | 12 ++++++++++++ tests/utils.py | 1 + 14 files changed, 146 insertions(+), 7 deletions(-) (limited to 'synapse/util') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 46b0d7b34c..aa15f73f36 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -141,6 +141,32 @@ class RegistrationError(SynapseError): pass +class FederationDeniedError(SynapseError): + """An error raised when the server tries to federate with a server which + is not on its federation whitelist. + + Attributes: + destination (str): The destination which has been denied + """ + + def __init__(self, destination): + """Raised by federation client or server to indicate that we are + are deliberately not attempting to contact a given server because it is + not on our federation whitelist. + + Args: + destination (str): the domain in question + """ + + self.destination = destination + + super(FederationDeniedError, self).__init__( + code=403, + msg="Federation denied with %s." % (self.destination,), + errcode=Codes.FORBIDDEN, + ) + + class InteractiveAuthIncompleteError(Exception): """An error raised when UI auth is not yet complete diff --git a/synapse/config/server.py b/synapse/config/server.py index 436dd8a6fe..8f0b6d1f28 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -55,6 +55,17 @@ class ServerConfig(Config): "block_non_admin_invites", False, ) + # FIXME: federation_domain_whitelist needs sytests + self.federation_domain_whitelist = None + federation_domain_whitelist = config.get( + "federation_domain_whitelist", None + ) + # turn the whitelist into a hash for speed of lookup + if federation_domain_whitelist is not None: + self.federation_domain_whitelist = {} + for domain in federation_domain_whitelist: + self.federation_domain_whitelist[domain] = True + if self.public_baseurl is not None: if self.public_baseurl[-1] != '/': self.public_baseurl += '/' @@ -210,6 +221,17 @@ class ServerConfig(Config): # (except those sent by local server admins). The default is False. # block_non_admin_invites: True + # Restrict federation to the following whitelist of domains. + # N.B. we recommend also firewalling your federation listener to limit + # inbound federation traffic as early as possible, rather than relying + # purely on this application-layer restriction. If not specified, the + # default is to whitelist everything. + # + # federation_domain_whitelist: + # - lon.example.com + # - nyc.example.com + # - syd.example.com + # List of ports that Synapse should listen on, their purpose and their # configuration. listeners: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b1fe03f702..813907f7f2 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -23,7 +23,7 @@ from twisted.internet import defer from synapse.api.constants import Membership from synapse.api.errors import ( - CodeMessageException, HttpResponseException, SynapseError, + CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError ) from synapse.events import builder from synapse.federation.federation_base import ( @@ -266,6 +266,9 @@ class FederationClient(FederationBase): except NotRetryingDestination as e: logger.info(e.message) continue + except FederationDeniedError as e: + logger.info(e.message) + continue except Exception as e: pdu_attempts[destination] = now diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 9d39f46583..a141ec9953 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -19,7 +19,7 @@ from twisted.internet import defer from .persistence import TransactionActions from .units import Transaction, Edu -from synapse.api.errors import HttpResponseException +from synapse.api.errors import HttpResponseException, FederationDeniedError from synapse.util import logcontext, PreserveLoggingContext from synapse.util.async import run_on_reactor from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter @@ -490,6 +490,8 @@ class TransactionQueue(object): (e.retry_last_ts + e.retry_interval) / 1000.0 ), ) + except FederationDeniedError as e: + logger.info(e) except Exception as e: logger.warn( "TX [%s] Failed to send transaction: %s", diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 1f3ce238f6..5488e82985 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -212,6 +212,9 @@ class TransportLayerClient(object): Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if the remote destination + is not in our federation whitelist """ valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 2b02b021ec..06c16ba4fa 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, SynapseError, FederationDeniedError from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, @@ -81,6 +81,7 @@ class Authenticator(object): self.keyring = hs.get_keyring() self.server_name = hs.hostname self.store = hs.get_datastore() + self.federation_domain_whitelist = hs.config.federation_domain_whitelist # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks @@ -92,6 +93,12 @@ class Authenticator(object): "signatures": {}, } + if ( + self.federation_domain_whitelist is not None and + self.server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(self.server_name) + if content is not None: json_request["content"] = content diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 2152efc692..0e83453851 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.api import errors from synapse.api.constants import EventTypes +from synapse.api.errors import FederationDeniedError from synapse.util import stringutils from synapse.util.async import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -513,6 +514,9 @@ class DeviceListEduUpdater(object): # This makes it more likely that the device lists will # eventually become consistent. return + except FederationDeniedError as e: + logger.info(e) + return except Exception: # TODO: Remember that we are now out of sync and try again # later diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5af8abf66b..9aa95f89e6 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -19,7 +19,9 @@ import logging from canonicaljson import encode_canonical_json from twisted.internet import defer -from synapse.api.errors import SynapseError, CodeMessageException +from synapse.api.errors import ( + SynapseError, CodeMessageException, FederationDeniedError, +) from synapse.types import get_domain_from_id, UserID from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination @@ -140,6 +142,10 @@ class E2eKeysHandler(object): failures[destination] = { "status": 503, "message": "Not ready for retry", } + except FederationDeniedError as e: + failures[destination] = { + "status": 403, "message": "Federation Denied", + } except Exception as e: # include ConnectionRefused and other errors failures[destination] = { diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ac70730885..677532c87b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -22,6 +22,7 @@ from ._base import BaseHandler from synapse.api.errors import ( AuthError, FederationError, StoreError, CodeMessageException, SynapseError, + FederationDeniedError, ) from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.events.validator import EventValidator @@ -782,6 +783,9 @@ class FederationHandler(BaseHandler): except NotRetryingDestination as e: logger.info(e.message) continue + except FederationDeniedError as e: + logger.info(e) + continue except Exception as e: logger.exception( "Failed to backfill from %s because %s", diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 833496b72d..9145405cb0 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -27,7 +27,7 @@ import synapse.metrics from canonicaljson import encode_canonical_json from synapse.api.errors import ( - SynapseError, Codes, HttpResponseException, + SynapseError, Codes, HttpResponseException, FederationDeniedError, ) from signedjson.sign import sign_json @@ -123,11 +123,22 @@ class MatrixFederationHttpClient(object): Fails with ``HTTPRequestException``: if we get an HTTP response code >= 300. + Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist + (May also fail with plenty of other Exceptions for things like DNS failures, connection failures, SSL failures.) """ + if ( + self.hs.config.federation_domain_whitelist and + destination not in self.hs.config.federation_domain_whitelist + ): + raise FederationDeniedError(destination) + limiter = yield synapse.util.retryutils.get_retry_limiter( destination, self.clock, @@ -308,6 +319,9 @@ class MatrixFederationHttpClient(object): Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ if not json_data_callback: @@ -368,6 +382,9 @@ class MatrixFederationHttpClient(object): Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ def body_callback(method, url_bytes, headers_dict): @@ -422,6 +439,9 @@ class MatrixFederationHttpClient(object): Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ logger.debug("get_json args: %s", args) @@ -475,6 +495,9 @@ class MatrixFederationHttpClient(object): Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ response = yield self._request( @@ -518,6 +541,9 @@ class MatrixFederationHttpClient(object): Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ encoded_args = {} diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index cc2842aa72..17e6079cba 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -93,6 +93,7 @@ class RemoteKey(Resource): self.store = hs.get_datastore() self.version_string = hs.version_string self.clock = hs.get_clock() + self.federation_domain_whitelist = hs.config.federation_domain_whitelist def render_GET(self, request): self.async_render_GET(request) @@ -137,6 +138,13 @@ class RemoteKey(Resource): logger.info("Handling query for keys %r", query) store_queries = [] for server_name, key_ids in query.items(): + if ( + self.federation_domain_whitelist is not None and + server_name not in self.federation_domain_whitelist + ): + logger.debug("Federation denied with %s", server_name) + continue + if not key_ids: key_ids = (None,) for key_id in key_ids: diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 4f56bcf577..485db8577a 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -32,8 +32,9 @@ from .media_storage import MediaStorage from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.util.stringutils import random_string -from synapse.api.errors import SynapseError, HttpResponseException, \ - NotFoundError +from synapse.api.errors import ( + SynapseError, HttpResponseException, NotFoundError, FederationDeniedError, +) from synapse.util.async import Linearizer from synapse.util.stringutils import is_ascii @@ -75,6 +76,8 @@ class MediaRepository(object): self.recently_accessed_remotes = set() self.recently_accessed_locals = set() + self.federation_domain_whitelist = hs.config.federation_domain_whitelist + # List of StorageProviders where we should search for media and # potentially upload to. storage_providers = [] @@ -216,6 +219,12 @@ class MediaRepository(object): Deferred: Resolves once a response has successfully been written to request """ + if ( + self.federation_domain_whitelist is not None and + server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + self.mark_recently_accessed(server_name, media_id) # We linearize here to ensure that we don't try and download remote @@ -250,6 +259,12 @@ class MediaRepository(object): Returns: Deferred[dict]: The media_info of the file """ + if ( + self.federation_domain_whitelist is not None and + server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 1adedbb361..47b0bb5eb3 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -26,6 +26,18 @@ logger = logging.getLogger(__name__) class NotRetryingDestination(Exception): def __init__(self, retry_last_ts, retry_interval, destination): + """Raised by the limiter (and federation client) to indicate that we are + are deliberately not attempting to contact a given server. + + Args: + retry_last_ts (int): the unix ts in milliseconds of our last attempt + to contact the server. 0 indicates that the last attempt was + successful or that we've never actually attempted to connect. + retry_interval (int): the time in milliseconds to wait until the next + attempt. + destination (str): the domain in question + """ + msg = "Not retrying server %s." % (destination,) super(NotRetryingDestination, self).__init__(msg) diff --git a/tests/utils.py b/tests/utils.py index 44e5f75093..3116047892 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -57,6 +57,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.worker_app = None config.email_enable_notifs = False config.block_non_admin_invites = False + config.federation_domain_whitelist = None # disable user directory updates, because they get done in the # background, which upsets the test runner. -- cgit 1.5.1 From bc496df192fa20dee933590d5f21a3425388c0d7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 1 Feb 2018 17:57:51 +0000 Subject: report metrics on number of cache evictions --- synapse/metrics/metric.py | 11 ++++++++++- synapse/util/caches/descriptors.py | 4 ++++ synapse/util/caches/expiringcache.py | 6 +++++- synapse/util/caches/lrucache.py | 28 +++++++++++++++++++++++++--- tests/metrics/test_metric.py | 12 ++++++++++++ 5 files changed, 56 insertions(+), 5 deletions(-) (limited to 'synapse/util') diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py index 1e783e5ff4..ff5aa8c0e1 100644 --- a/synapse/metrics/metric.py +++ b/synapse/metrics/metric.py @@ -193,7 +193,9 @@ class DistributionMetric(object): class CacheMetric(object): - __slots__ = ("name", "cache_name", "hits", "misses", "size_callback") + __slots__ = ( + "name", "cache_name", "hits", "misses", "evicted_size", "size_callback", + ) def __init__(self, name, size_callback, cache_name): self.name = name @@ -201,6 +203,7 @@ class CacheMetric(object): self.hits = 0 self.misses = 0 + self.evicted_size = 0 self.size_callback = size_callback @@ -210,6 +213,9 @@ class CacheMetric(object): def inc_misses(self): self.misses += 1 + def inc_evictions(self, size=1): + self.evicted_size += size + def render(self): size = self.size_callback() hits = self.hits @@ -219,6 +225,9 @@ class CacheMetric(object): """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits), """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total), """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size), + """%s:evicted_size{name="%s"} %d""" % ( + self.name, self.cache_name, self.evicted_size + ), ] diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index af65bfe7b8..bf3a66eae4 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -75,6 +75,7 @@ class Cache(object): self.cache = LruCache( max_size=max_entries, keylen=keylen, cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, + evicted_callback=self._on_evicted, ) self.name = name @@ -83,6 +84,9 @@ class Cache(object): self.thread = None self.metrics = register_cache(name, self.cache) + def _on_evicted(self, evicted_count): + self.metrics.inc_evictions(evicted_count) + def check_thread(self): expected_thread = self.thread if expected_thread is None: diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 6ad53a6390..0aa103eecb 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -79,7 +79,11 @@ class ExpiringCache(object): while self._max_len and len(self) > self._max_len: _key, value = self._cache.popitem(last=False) if self.iterable: - self._size_estimate -= len(value.value) + removed_len = len(value.value) + self.metrics.inc_evictions(removed_len) + self._size_estimate -= removed_len + else: + self.metrics.inc_evictions() def __getitem__(self, key): try: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index cf5fbb679c..f088dd430e 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,7 +49,24 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): + def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None, + evicted_callback=None): + """ + Args: + max_size (int): + + keylen (int): + + cache_type (type): + type of underlying cache to be used. Typically one of dict + or TreeCache. + + size_callback (func(V) -> int | None): + + evicted_callback (func(int)|None): + if not None, called on eviction with the size of the evicted + entry + """ cache = cache_type() self.cache = cache # Used for introspection. list_root = _Node(None, None, None, None) @@ -61,8 +78,10 @@ class LruCache(object): def evict(): while cache_len() > max_size: todelete = list_root.prev_node - delete_node(todelete) + evicted_len = delete_node(todelete) cache.pop(todelete.key, None) + if evicted_callback: + evicted_callback(evicted_len) def synchronized(f): @wraps(f) @@ -111,12 +130,15 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + deleted_len = 1 if size_callback: - cached_cache_len[0] -= size_callback(node.value) + deleted_len = size_callback(node.value) + cached_cache_len[0] -= deleted_len for cb in node.callbacks: cb() node.callbacks.clear() + return deleted_len @synchronized def cache_get(key, default=None, callbacks=[]): diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index f85455a5af..39bde6e3f8 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -141,6 +141,7 @@ class CacheMetricTestCase(unittest.TestCase): 'cache:hits{name="cache_name"} 0', 'cache:total{name="cache_name"} 0', 'cache:size{name="cache_name"} 0', + 'cache:evicted_size{name="cache_name"} 0', ]) metric.inc_misses() @@ -150,6 +151,7 @@ class CacheMetricTestCase(unittest.TestCase): 'cache:hits{name="cache_name"} 0', 'cache:total{name="cache_name"} 1', 'cache:size{name="cache_name"} 1', + 'cache:evicted_size{name="cache_name"} 0', ]) metric.inc_hits() @@ -158,4 +160,14 @@ class CacheMetricTestCase(unittest.TestCase): 'cache:hits{name="cache_name"} 1', 'cache:total{name="cache_name"} 2', 'cache:size{name="cache_name"} 1', + 'cache:evicted_size{name="cache_name"} 0', + ]) + + metric.inc_evictions(2) + + self.assertEquals(metric.render(), [ + 'cache:hits{name="cache_name"} 1', + 'cache:total{name="cache_name"} 2', + 'cache:size{name="cache_name"} 1', + 'cache:evicted_size{name="cache_name"} 2', ]) -- cgit 1.5.1 From 3a75de923b9183c073bbddae1e08fae546a11f7a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 1 Mar 2018 12:19:09 +0000 Subject: Rewrite make_deferred_yieldable avoiding inlineCallbacks ... because (a) it's actually simpler (b) it might be marginally more performant? --- synapse/util/logcontext.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 94fa7cac98..a8dea15c1b 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -299,10 +299,6 @@ def preserve_fn(f): Useful for wrapping functions that return a deferred which you don't yield on. """ - def reset_context(result): - LoggingContext.set_current_context(LoggingContext.sentinel) - return result - def g(*args, **kwargs): current = LoggingContext.current_context() res = f(*args, **kwargs) @@ -323,12 +319,11 @@ def preserve_fn(f): # which is supposed to have a single entry and exit point. But # by spawning off another deferred, we are effectively # adding a new exit point.) - res.addBoth(reset_context) + res.addBoth(_set_context_cb, LoggingContext.sentinel) return res return g -@defer.inlineCallbacks def make_deferred_yieldable(deferred): """Given a deferred, make it follow the Synapse logcontext rules: @@ -342,9 +337,16 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to preserve_fn.) """ - with PreserveLoggingContext(): - r = yield deferred - defer.returnValue(r) + if isinstance(deferred, defer.Deferred) and not deferred.called: + prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + deferred.addBoth(_set_context_cb, prev_context) + return deferred + + +def _set_context_cb(result, context): + """A callback function which just sets the logging context""" + LoggingContext.set_current_context(context) + return result # modules to ignore in `logcontext_tracer` -- cgit 1.5.1 From 20f40348d4ea55cc5b98528673e26bac7396a3cb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 7 Mar 2018 19:59:24 +0000 Subject: Factor run_in_background out from preserve_fn It annoys me that we create temporary function objects when there's really no need for it. Let's factor the gubbins out of preserve_fn and start using it. --- docs/log_contexts.rst | 8 +++---- synapse/util/logcontext.py | 53 +++++++++++++++++++++++++--------------------- 2 files changed, 33 insertions(+), 28 deletions(-) (limited to 'synapse/util') diff --git a/docs/log_contexts.rst b/docs/log_contexts.rst index b19b7fa1ea..82ac4f91e5 100644 --- a/docs/log_contexts.rst +++ b/docs/log_contexts.rst @@ -279,9 +279,9 @@ Obviously that option means that the operations done in that might be fixed by setting a different logcontext via a ``with LoggingContext(...)`` in ``background_operation``). -The second option is to use ``logcontext.preserve_fn``, which wraps a function -so that it doesn't reset the logcontext even when it returns an incomplete -deferred, and adds a callback to the returned deferred to reset the +The second option is to use ``logcontext.run_in_background``, which wraps a +function so that it doesn't reset the logcontext even when it returns an +incomplete deferred, and adds a callback to the returned deferred to reset the logcontext. In other words, it turns a function that follows the Synapse rules about logcontexts and Deferreds into one which behaves more like an external function — the opposite operation to that described in the previous section. @@ -293,7 +293,7 @@ It can be used like this: def do_request_handling(): yield foreground_operation() - logcontext.preserve_fn(background_operation)() + logcontext.run_in_background(background_operation) # this will now be logged against the request context logger.debug("Request handling complete") diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index a8dea15c1b..d660ec785b 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -292,36 +292,41 @@ class PreserveLoggingContext(object): def preserve_fn(f): - """Wraps a function, to ensure that the current context is restored after + """Function decorator which wraps the function with run_in_background""" + def g(*args, **kwargs): + return run_in_background(f, *args, **kwargs) + return g + + +def run_in_background(f, *args, **kwargs): + """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the deferred returned by the funtion completes. Useful for wrapping functions that return a deferred which you don't yield on. """ - def g(*args, **kwargs): - current = LoggingContext.current_context() - res = f(*args, **kwargs) - if isinstance(res, defer.Deferred) and not res.called: - # The function will have reset the context before returning, so - # we need to restore it now. - LoggingContext.set_current_context(current) - - # The original context will be restored when the deferred - # completes, but there is nothing waiting for it, so it will - # get leaked into the reactor or some other function which - # wasn't expecting it. We therefore need to reset the context - # here. - # - # (If this feels asymmetric, consider it this way: we are - # effectively forking a new thread of execution. We are - # probably currently within a ``with LoggingContext()`` block, - # which is supposed to have a single entry and exit point. But - # by spawning off another deferred, we are effectively - # adding a new exit point.) - res.addBoth(_set_context_cb, LoggingContext.sentinel) - return res - return g + current = LoggingContext.current_context() + res = f(*args, **kwargs) + if isinstance(res, defer.Deferred) and not res.called: + # The function will have reset the context before returning, so + # we need to restore it now. + LoggingContext.set_current_context(current) + + # The original context will be restored when the deferred + # completes, but there is nothing waiting for it, so it will + # get leaked into the reactor or some other function which + # wasn't expecting it. We therefore need to reset the context + # here. + # + # (If this feels asymmetric, consider it this way: we are + # effectively forking a new thread of execution. We are + # probably currently within a ``with LoggingContext()`` block, + # which is supposed to have a single entry and exit point. But + # by spawning off another deferred, we are effectively + # adding a new exit point.) + res.addBoth(_set_context_cb, LoggingContext.sentinel) + return res def make_deferred_yieldable(deferred): -- cgit 1.5.1 From 7c7706f42b56dd61f5eb17679aa12247f7058ed5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 15 Mar 2018 15:40:13 +0000 Subject: Fix bug where state cache used lots of memory The state cache bases its size on the sum of the size of entries. The size of the entry is calculated once on insertion, so it is important that the size of entries does not change. The DictionaryCache modified the entries size, which caused the state cache to incorrectly think it was smaller than it actually was. --- synapse/util/caches/dictionary_cache.py | 6 +++++- synapse/util/caches/lrucache.py | 8 ++++---- 2 files changed, 9 insertions(+), 5 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index d4105822b3..1709e8b429 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -132,9 +132,13 @@ class DictionaryCache(object): self._update_or_insert(key, value, known_absent) def _update_or_insert(self, key, value, known_absent): - entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {})) + # We pop and reinsert as we need to tell the cache the size may have + # changed + + entry = self.cache.pop(key, DictionaryEntry(False, set(), {})) entry.value.update(value) entry.known_absent.update(known_absent) + self.cache[key] = entry def _insert(self, key, value, known_absent): self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index f088dd430e..a4bf8fa6ae 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -154,14 +154,14 @@ class LruCache(object): def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: - if value != node.value: + if node.callbacks and value != node.value: for cb in node.callbacks: cb() node.callbacks.clear() - if size_callback: - cached_cache_len[0] -= size_callback(node.value) - cached_cache_len[0] += size_callback(value) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) node.callbacks.update(callbacks) -- cgit 1.5.1 From 9a0d783c113ae74c55e409d33219cd77f3662b9f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 19 Mar 2018 11:35:53 +0000 Subject: Add comments --- synapse/util/caches/lrucache.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'synapse/util') diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index a4bf8fa6ae..1c5a982094 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -154,11 +154,18 @@ class LruCache(object): def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: + # We sometimes store large objects, e.g. dicts, which cause + # the inequality check to take a long time. So let's only do + # the check if we have some callbacks to call. if node.callbacks and value != node.value: for cb in node.callbacks: cb() node.callbacks.clear() + # We don't bother to protect this by value != node.value as + # generally size_callback will be cheap compared with equality + # checks. (For example, taking the size of two dicts is quicker + # than comparing them for equality.) if size_callback: cached_cache_len[0] -= size_callback(node.value) cached_cache_len[0] += size_callback(value) -- cgit 1.5.1