diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 463a737efa..39e360fe24 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -21,8 +21,13 @@ import mock
from twisted.internet import defer, reactor
from synapse.api.errors import SynapseError
-from synapse.util import logcontext
+from synapse.logging.context import (
+ LoggingContext,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+)
from synapse.util.caches import descriptors
+from synapse.util.caches.descriptors import cached
from tests import unittest
@@ -32,7 +37,7 @@ logger = logging.getLogger(__name__)
def run_on_reactor():
d = defer.Deferred()
reactor.callLater(0, d.callback, 0)
- return logcontext.make_deferred_yieldable(d)
+ return make_deferred_yieldable(d)
class CacheTestCase(unittest.TestCase):
@@ -51,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
- # lookup should return the deferreds
- self.assertIs(cache.get("key1"), d1)
- self.assertIs(cache.get("key2"), d2)
+ # lookup should return observable deferreds
+ self.assertFalse(cache.get("key1").has_called())
+ self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
+
+ # for now at least, the cache will return real results rather than an
+ # ObservableDeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
@@ -88,24 +96,24 @@ class DescriptorTestCase(unittest.TestCase):
obj = Cls()
- obj.mock.return_value = 'fish'
+ obj.mock.return_value = "fish"
r = yield obj.fn(1, 2)
- self.assertEqual(r, 'fish')
+ self.assertEqual(r, "fish")
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
- obj.mock.return_value = 'chips'
+ obj.mock.return_value = "chips"
r = yield obj.fn(1, 3)
- self.assertEqual(r, 'chips')
+ self.assertEqual(r, "chips")
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
# the two values should now be cached
r = yield obj.fn(1, 2)
- self.assertEqual(r, 'fish')
+ self.assertEqual(r, "fish")
r = yield obj.fn(1, 3)
- self.assertEqual(r, 'chips')
+ self.assertEqual(r, "chips")
obj.mock.assert_not_called()
@defer.inlineCallbacks
@@ -121,27 +129,49 @@ class DescriptorTestCase(unittest.TestCase):
return self.mock(arg1, arg2)
obj = Cls()
- obj.mock.return_value = 'fish'
+ obj.mock.return_value = "fish"
r = yield obj.fn(1, 2)
- self.assertEqual(r, 'fish')
+ self.assertEqual(r, "fish")
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
- obj.mock.return_value = 'chips'
+ obj.mock.return_value = "chips"
r = yield obj.fn(2, 3)
- self.assertEqual(r, 'chips')
+ self.assertEqual(r, "chips")
obj.mock.assert_called_once_with(2, 3)
obj.mock.reset_mock()
# the two values should now be cached; we should be able to vary
# the second argument and still get the cached result.
r = yield obj.fn(1, 4)
- self.assertEqual(r, 'fish')
+ self.assertEqual(r, "fish")
r = yield obj.fn(2, 5)
- self.assertEqual(r, 'chips')
+ self.assertEqual(r, "chips")
obj.mock.assert_not_called()
+ def test_cache_with_sync_exception(self):
+ """If the wrapped function throws synchronously, things should continue to work
+ """
+
+ class Cls(object):
+ @cached()
+ def fn(self, arg1):
+ raise SynapseError(100, "mai spoon iz too big!!1")
+
+ obj = Cls()
+
+ # this should fail immediately
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
+ # ... leaving the cache empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
+ # and a second call should result in a second exception
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
@@ -153,19 +183,19 @@ class DescriptorTestCase(unittest.TestCase):
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
- with logcontext.PreserveLoggingContext():
+ with PreserveLoggingContext():
yield complete_lookup
- defer.returnValue(1)
+ return 1
return inner_fn()
@defer.inlineCallbacks
def do_lookup():
- with logcontext.LoggingContext() as c1:
+ with LoggingContext() as c1:
c1.name = "c1"
r = yield obj.fn(1)
- self.assertEqual(logcontext.LoggingContext.current_context(), c1)
- defer.returnValue(r)
+ self.assertEqual(LoggingContext.current_context(), c1)
+ return r
def check_result(r):
self.assertEqual(r, 1)
@@ -174,18 +204,12 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(
- logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel,
- )
+ self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
d1.addCallback(check_result)
# and another
d2 = do_lookup()
- self.assertEqual(
- logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel,
- )
+ self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
d2.addCallback(check_result)
# let the lookup complete
@@ -210,29 +234,28 @@ class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def do_lookup():
- with logcontext.LoggingContext() as c1:
+ with LoggingContext() as c1:
c1.name = "c1"
try:
d = obj.fn(1)
self.assertEqual(
- logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel,
+ LoggingContext.current_context(), LoggingContext.sentinel
)
yield d
self.fail("No exception thrown")
except SynapseError:
pass
- self.assertEqual(logcontext.LoggingContext.current_context(), c1)
+ self.assertEqual(LoggingContext.current_context(), c1)
+
+ # the cache should now be empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
obj = Cls()
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(
- logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel,
- )
+ self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
return d1
@@ -248,32 +271,87 @@ class DescriptorTestCase(unittest.TestCase):
obj = Cls()
- obj.mock.return_value = 'fish'
+ obj.mock.return_value = "fish"
r = yield obj.fn(1, 2, 3)
- self.assertEqual(r, 'fish')
+ self.assertEqual(r, "fish")
obj.mock.assert_called_once_with(1, 2, 3)
obj.mock.reset_mock()
# a call with same params shouldn't call the mock again
r = yield obj.fn(1, 2)
- self.assertEqual(r, 'fish')
+ self.assertEqual(r, "fish")
obj.mock.assert_not_called()
obj.mock.reset_mock()
# a call with different params should call the mock again
- obj.mock.return_value = 'chips'
+ obj.mock.return_value = "chips"
r = yield obj.fn(2, 3)
- self.assertEqual(r, 'chips')
+ self.assertEqual(r, "chips")
obj.mock.assert_called_once_with(2, 3, 3)
obj.mock.reset_mock()
# the two values should now be cached
r = yield obj.fn(1, 2)
- self.assertEqual(r, 'fish')
+ self.assertEqual(r, "fish")
r = yield obj.fn(2, 3)
- self.assertEqual(r, 'chips')
+ self.assertEqual(r, "chips")
+ obj.mock.assert_not_called()
+
+ def test_cache_iterable(self):
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached(iterable=True)
+ def fn(self, arg1, arg2):
+ return self.mock(arg1, arg2)
+
+ obj = Cls()
+
+ obj.mock.return_value = ["spam", "eggs"]
+ r = obj.fn(1, 2)
+ self.assertEqual(r.result, ["spam", "eggs"])
+ obj.mock.assert_called_once_with(1, 2)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = ["chips"]
+ r = obj.fn(1, 3)
+ self.assertEqual(r.result, ["chips"])
+ obj.mock.assert_called_once_with(1, 3)
+ obj.mock.reset_mock()
+
+ # the two values should now be cached; with iterable=True the cache's size
+ # counts the elements of each entry, hence 2 + 1 = 3
+ self.assertEqual(len(obj.fn.cache.cache), 3)
+
+ r = obj.fn(1, 2)
+ self.assertEqual(r.result, ["spam", "eggs"])
+ r = obj.fn(1, 3)
+ self.assertEqual(r.result, ["chips"])
obj.mock.assert_not_called()
+ def test_cache_iterable_with_sync_exception(self):
+ """If the wrapped function throws synchronously, things should continue to work
+ """
+
+ class Cls(object):
+ @descriptors.cached(iterable=True)
+ def fn(self, arg1):
+ raise SynapseError(100, "mai spoon iz too big!!1")
+
+ obj = Cls()
+
+ # this should fail immediately
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
+ # ... leaving the cache empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
+ # and a second call should result in a second exception
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
@@ -288,44 +366,41 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
- assert logcontext.LoggingContext.current_context().request == "c1"
+ assert LoggingContext.current_context().request == "c1"
# we want this to behave like an asynchronous function
yield run_on_reactor()
- assert logcontext.LoggingContext.current_context().request == "c1"
- defer.returnValue(self.mock(args1, arg2))
+ assert LoggingContext.current_context().request == "c1"
+ return self.mock(args1, arg2)
- with logcontext.LoggingContext() as c1:
+ with LoggingContext() as c1:
c1.request = "c1"
obj = Cls()
- obj.mock.return_value = {10: 'fish', 20: 'chips'}
+ obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2)
- self.assertEqual(
- logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel,
- )
+ self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
r = yield d1
- self.assertEqual(logcontext.LoggingContext.current_context(), c1)
+ self.assertEqual(LoggingContext.current_context(), c1)
obj.mock.assert_called_once_with([10, 20], 2)
- self.assertEqual(r, {10: 'fish', 20: 'chips'})
+ self.assertEqual(r, {10: "fish", 20: "chips"})
obj.mock.reset_mock()
# a call with different params should call the mock again
- obj.mock.return_value = {30: 'peas'}
+ obj.mock.return_value = {30: "peas"}
r = yield obj.list_fn([20, 30], 2)
obj.mock.assert_called_once_with([30], 2)
- self.assertEqual(r, {20: 'chips', 30: 'peas'})
+ self.assertEqual(r, {20: "chips", 30: "peas"})
obj.mock.reset_mock()
# all the values should now be cached
r = yield obj.fn(10, 2)
- self.assertEqual(r, 'fish')
+ self.assertEqual(r, "fish")
r = yield obj.fn(20, 2)
- self.assertEqual(r, 'chips')
+ self.assertEqual(r, "chips")
r = yield obj.fn(30, 2)
- self.assertEqual(r, 'peas')
+ self.assertEqual(r, "peas")
r = yield obj.list_fn([10, 20, 30], 2)
obj.mock.assert_not_called()
- self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+ self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})
@defer.inlineCallbacks
def test_invalidate(self):
@@ -343,23 +418,23 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
yield run_on_reactor()
- defer.returnValue(self.mock(args1, arg2))
+ return self.mock(args1, arg2)
obj = Cls()
invalidate0 = mock.Mock()
invalidate1 = mock.Mock()
# cache miss
- obj.mock.return_value = {10: 'fish', 20: 'chips'}
+ obj.mock.return_value = {10: "fish", 20: "chips"}
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
obj.mock.assert_called_once_with([10, 20], 2)
- self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+ self.assertEqual(r1, {10: "fish", 20: "chips"})
obj.mock.reset_mock()
# cache hit
r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
obj.mock.assert_not_called()
- self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+ self.assertEqual(r2, {10: "fish", 20: "chips"})
invalidate0.assert_not_called()
invalidate1.assert_not_called()
diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
index 03b3c15db6..816795c136 100644
--- a/tests/util/caches/test_ttlcache.py
+++ b/tests/util/caches/test_ttlcache.py
@@ -27,57 +27,57 @@ class CacheTestCase(unittest.TestCase):
def test_get(self):
"""simple set/get tests"""
- self.cache.set('one', '1', 10)
- self.cache.set('two', '2', 20)
- self.cache.set('three', '3', 30)
+ self.cache.set("one", "1", 10)
+ self.cache.set("two", "2", 20)
+ self.cache.set("three", "3", 30)
self.assertEqual(len(self.cache), 3)
- self.assertTrue('one' in self.cache)
- self.assertEqual(self.cache.get('one'), '1')
- self.assertEqual(self.cache['one'], '1')
- self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110))
+ self.assertTrue("one" in self.cache)
+ self.assertEqual(self.cache.get("one"), "1")
+ self.assertEqual(self.cache["one"], "1")
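+ # get_with_expiry now returns a (value, expiry time, TTL) triple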
+ self.assertEqual(self.cache.get_with_expiry("one"), ("1", 110, 10))
self.assertEqual(self.cache._metrics.hits, 3)
self.assertEqual(self.cache._metrics.misses, 0)
- self.cache.set('two', '2.5', 20)
- self.assertEqual(self.cache['two'], '2.5')
+ self.cache.set("two", "2.5", 20)
+ self.assertEqual(self.cache["two"], "2.5")
self.assertEqual(self.cache._metrics.hits, 4)
# non-existent-item tests
- self.assertEqual(self.cache.get('four', '4'), '4')
- self.assertIs(self.cache.get('four', None), None)
+ self.assertEqual(self.cache.get("four", "4"), "4")
+ self.assertIs(self.cache.get("four", None), None)
with self.assertRaises(KeyError):
- self.cache['four']
+ self.cache["four"]
with self.assertRaises(KeyError):
- self.cache.get('four')
+ self.cache.get("four")
with self.assertRaises(KeyError):
- self.cache.get_with_expiry('four')
+ self.cache.get_with_expiry("four")
self.assertEqual(self.cache._metrics.hits, 4)
self.assertEqual(self.cache._metrics.misses, 5)
def test_expiry(self):
- self.cache.set('one', '1', 10)
- self.cache.set('two', '2', 20)
- self.cache.set('three', '3', 30)
+ self.cache.set("one", "1", 10)
+ self.cache.set("two", "2", 20)
+ self.cache.set("three", "3", 30)
self.assertEqual(len(self.cache), 3)
- self.assertEqual(self.cache['one'], '1')
- self.assertEqual(self.cache['two'], '2')
+ self.assertEqual(self.cache["one"], "1")
+ self.assertEqual(self.cache["two"], "2")
# enough for the first entry to expire, but not the rest
self.mock_timer.side_effect = lambda: 110.0
self.assertEqual(len(self.cache), 2)
- self.assertFalse('one' in self.cache)
- self.assertEqual(self.cache['two'], '2')
- self.assertEqual(self.cache['three'], '3')
+ self.assertFalse("one" in self.cache)
+ self.assertEqual(self.cache["two"], "2")
+ self.assertEqual(self.cache["three"], "3")
- self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120))
+ self.assertEqual(self.cache.get_with_expiry("two"), ("2", 120, 20))
self.assertEqual(self.cache._metrics.hits, 5)
self.assertEqual(self.cache._metrics.misses, 0)
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py
index bf85d3b8ec..f60918069a 100644
--- a/tests/util/test_async_utils.py
+++ b/tests/util/test_async_utils.py
@@ -16,9 +16,8 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.task import Clock
-from synapse.util import logcontext
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.util.async_helpers import timeout_deferred
-from synapse.util.logcontext import LoggingContext
from tests.unittest import TestCase
@@ -69,14 +68,14 @@ class TimeoutDeferredTest(TestCase):
@defer.inlineCallbacks
def blocking():
non_completing_d = Deferred()
- with logcontext.PreserveLoggingContext():
+ with PreserveLoggingContext():
try:
yield non_completing_d
except CancelledError:
blocking_was_cancelled[0] = True
raise
- with logcontext.LoggingContext("one") as context_one:
+ with LoggingContext("one") as context_one:
# the errbacks should be run in the test logcontext
def errback(res, deferred_name):
self.assertIs(
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
new file mode 100644
index 0000000000..0ab0a91483
--- /dev/null
+++ b/tests/util/test_itertools.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.util.iterutils import chunk_seq
+
+from tests.unittest import TestCase
+
+
+class ChunkSeqTests(TestCase):
+ def test_short_seq(self):
+ parts = chunk_seq("123", 8)
+
+ self.assertEqual(
+ list(parts), ["123"],
+ )
+
+ def test_long_seq(self):
+ parts = chunk_seq("abcdefghijklmnop", 8)
+
+ self.assertEqual(
+ list(parts), ["abcdefgh", "ijklmnop"],
+ )
+
+ def test_uneven_parts(self):
+ parts = chunk_seq("abcdefghijklmnop", 5)
+
+ self.assertEqual(
+ list(parts), ["abcde", "fghij", "klmno", "p"],
+ )
+
+ def test_empty_input(self):
+ parts = chunk_seq([], 5)
+
+ self.assertEqual(
+ list(parts), [],
+ )
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index ec7ba9719c..0ec8ef90ce 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -19,7 +19,8 @@ from six.moves import range
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
-from synapse.util import Clock, logcontext
+from synapse.logging.context import LoggingContext
+from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from tests import unittest
@@ -51,13 +52,13 @@ class LinearizerTestCase(unittest.TestCase):
@defer.inlineCallbacks
def func(i, sleep=False):
- with logcontext.LoggingContext("func(%s)" % i) as lc:
+ with LoggingContext("func(%s)" % i) as lc:
with (yield linearizer.queue("")):
- self.assertEqual(logcontext.LoggingContext.current_context(), lc)
+ self.assertEqual(LoggingContext.current_context(), lc)
if sleep:
yield Clock(reactor).sleep(0)
- self.assertEqual(logcontext.LoggingContext.current_context(), lc)
+ self.assertEqual(LoggingContext.current_context(), lc)
func(0, sleep=True)
for i in range(1, 100):
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8adaee3c8d..281b32c4b8 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -1,8 +1,14 @@
import twisted.python.failure
from twisted.internet import defer, reactor
-from synapse.util import Clock, logcontext
-from synapse.util.logcontext import LoggingContext
+from synapse.logging.context import (
+ LoggingContext,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+ nested_logging_context,
+ run_in_background,
+)
+from synapse.util import Clock
from .. import unittest
@@ -39,24 +45,17 @@ class LoggingContextTestCase(unittest.TestCase):
callback_completed = [False]
- def test():
+ with LoggingContext() as context_one:
context_one.request = "one"
- d = function()
+
+ # fire off the function, but don't wait on it.
+ d2 = run_in_background(function)
def cb(res):
- self._check_test_key("one")
callback_completed[0] = True
return res
- d.addCallback(cb)
-
- return d
-
- with LoggingContext() as context_one:
- context_one.request = "one"
-
- # fire off function, but don't wait on it.
- logcontext.run_in_background(test)
+ d2.addCallback(cb)
self._check_test_key("one")
@@ -92,7 +91,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_run_in_background_with_non_blocking_fn(self):
@defer.inlineCallbacks
def nonblocking_function():
- with logcontext.PreserveLoggingContext():
+ with PreserveLoggingContext():
yield defer.succeed(None)
return self._test_run_in_background(nonblocking_function)
@@ -101,7 +100,23 @@ class LoggingContextTestCase(unittest.TestCase):
# a function which returns a deferred which looks like it has been
# called, but is actually paused
def testfunc():
- return logcontext.make_deferred_yieldable(_chained_deferred_function())
+ return make_deferred_yieldable(_chained_deferred_function())
+
+ return self._test_run_in_background(testfunc)
+
+ def test_run_in_background_with_coroutine(self):
+ async def testfunc():
+ self._check_test_key("one")
+ d = Clock(reactor).sleep(0)
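+ # sleep() follows the logcontext rules, so we should now be in the sentinel context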
+ self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ await d
+ self._check_test_key("one")
+
+ return self._test_run_in_background(testfunc)
+
+ def test_run_in_background_with_nonblocking_coroutine(self):
+ async def testfunc():
+ self._check_test_key("one")
return self._test_run_in_background(testfunc)
@@ -119,7 +134,7 @@ class LoggingContextTestCase(unittest.TestCase):
with LoggingContext() as context_one:
context_one.request = "one"
- d1 = logcontext.make_deferred_yieldable(blocking_function())
+ d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context)
@@ -135,7 +150,7 @@ class LoggingContextTestCase(unittest.TestCase):
with LoggingContext() as context_one:
context_one.request = "one"
- d1 = logcontext.make_deferred_yieldable(_chained_deferred_function())
+ d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context)
@@ -152,7 +167,7 @@ class LoggingContextTestCase(unittest.TestCase):
with LoggingContext() as context_one:
context_one.request = "one"
- d1 = logcontext.make_deferred_yieldable("bum")
+ d1 = make_deferred_yieldable("bum")
self._check_test_key("one")
r = yield d1
@@ -161,9 +176,33 @@ class LoggingContextTestCase(unittest.TestCase):
def test_nested_logging_context(self):
with LoggingContext(request="foo"):
- nested_context = logcontext.nested_logging_context(suffix="bar")
+ nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.request, "foo-bar")
+ @defer.inlineCallbacks
+ def test_make_deferred_yieldable_with_await(self):
+ # an async function which returns an incomplete coroutine, but doesn't
+ # follow the Synapse logcontext rules.
+
+ async def blocking_function():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, None)
+ await d
+
+ sentinel_context = LoggingContext.current_context()
+
+ with LoggingContext() as context_one:
+ context_one.request = "one"
+
+ d1 = make_deferred_yieldable(blocking_function())
+ # make sure that the context was reset by make_deferred_yieldable
+ self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+ yield d1
+
+ # now it should be restored
+ self._check_test_key("one")
+
# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on
diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py
index 297aebbfbe..0fb60caacb 100644
--- a/tests/util/test_logformatter.py
+++ b/tests/util/test_logformatter.py
@@ -14,7 +14,7 @@
# limitations under the License.
import sys
-from synapse.util.logformatter import LogFormatter
+from synapse.logging.formatter import LogFormatter
from tests import unittest
diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
new file mode 100644
index 0000000000..4d1aee91d5
--- /dev/null
+++ b/tests/util/test_ratelimitutils.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.config.homeserver import HomeServerConfig
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests.server import get_clock
+from tests.unittest import TestCase
+from tests.utils import default_config
+
+
+class FederationRateLimiterTestCase(TestCase):
+ def test_ratelimit(self):
+ """A simple test with the default values"""
+ reactor, clock = get_clock()
+ rc_config = build_rc_config()
+ ratelimiter = FederationRateLimiter(clock, rc_config)
+
+ with ratelimiter.ratelimit("testhost") as d1:
+ # shouldn't block
+ self.successResultOf(d1)
+
+ def test_concurrent_limit(self):
+ """Test what happens when we hit the concurrent limit"""
+ reactor, clock = get_clock()
+ rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
+ ratelimiter = FederationRateLimiter(clock, rc_config)
+
+ with ratelimiter.ratelimit("testhost") as d1:
+ # shouldn't block
+ self.successResultOf(d1)
+
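+ # use the context-manager protocol by hand so that we can release this
+ # request part-way through the test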
+ cm2 = ratelimiter.ratelimit("testhost")
+ d2 = cm2.__enter__()
+ # also shouldn't block
+ self.successResultOf(d2)
+
+ cm3 = ratelimiter.ratelimit("testhost")
+ d3 = cm3.__enter__()
+ # this one should block, though ...
+ self.assertNoResult(d3)
+
+ # ... until we complete an earlier request
+ cm2.__exit__(None, None, None)
+ self.successResultOf(d3)
+
+ def test_sleep_limit(self):
+ """Test what happens when we hit the sleep limit"""
+ reactor, clock = get_clock()
+ rc_config = build_rc_config(
+ {"rc_federation": {"sleep_limit": 2, "sleep_delay": 500}}
+ )
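+ # with sleep_limit=2, a third request in quick succession should be
+ # delayed by sleep_delay (500ms)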
+ ratelimiter = FederationRateLimiter(clock, rc_config)
+
+ with ratelimiter.ratelimit("testhost") as d1:
+ # shouldn't block
+ self.successResultOf(d1)
+
+ with ratelimiter.ratelimit("testhost") as d2:
+ # nor this
+ self.successResultOf(d2)
+
+ with ratelimiter.ratelimit("testhost") as d3:
+ # this one should block, though ...
+ self.assertNoResult(d3)
+ sleep_time = _await_resolution(reactor, d3)
+ self.assertAlmostEqual(sleep_time, 500, places=3)
+
+
+def _await_resolution(reactor, d):
+ """advance the clock until the deferred completes.
+
+ Returns the number of milliseconds it took to complete.
+ """
+ start_time = reactor.seconds()
+ while not d.called:
+ reactor.advance(0.01)
+ return (reactor.seconds() - start_time) * 1000
+
+
+def build_rc_config(settings={}):
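+ """Build an rc_federation config, overlaying the given settings on the test defaults"""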
+ config_dict = default_config("test")
+ config_dict.update(settings)
+ config = HomeServerConfig()
+ config.parse_config_dict(config_dict, "", "")
+ return config.rc_federation
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
new file mode 100644
index 0000000000..9e348694ad
--- /dev/null
+++ b/tests/util/test_retryutils.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.util.retryutils import (
+ MIN_RETRY_INTERVAL,
+ RETRY_MULTIPLIER,
+ NotRetryingDestination,
+ get_retry_limiter,
+)
+
+from tests.unittest import HomeserverTestCase
+
+
+class RetryLimiterTestCase(HomeserverTestCase):
+ def test_new_destination(self):
+ """A happy-path case with a new destination and a successful operation"""
+ store = self.hs.get_datastore()
+ d = get_retry_limiter("test_dest", self.clock, store)
+ self.pump()
+ limiter = self.successResultOf(d)
+
+ # advance the clock a bit before making the request
+ self.pump(1)
+
+ with limiter:
+ pass
+
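+ # a successful request should leave no retry timings in the store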
+ d = store.get_destination_retry_timings("test_dest")
+ self.pump()
+ new_timings = self.successResultOf(d)
+ self.assertIsNone(new_timings)
+
+ def test_limiter(self):
+ """General test case which walks through the process of a failing request"""
+ store = self.hs.get_datastore()
+
+ d = get_retry_limiter("test_dest", self.clock, store)
+ self.pump()
+ limiter = self.successResultOf(d)
+
+ self.pump(1)
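+ # simulate a failed request, so that the limiter records the failure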
+ try:
+ with limiter:
+ self.pump(1)
+ failure_ts = self.clock.time_msec()
+ raise AssertionError("argh")
+ except AssertionError:
+ pass
+
+ # wait for the update to land
+ self.pump()
+
+ d = store.get_destination_retry_timings("test_dest")
+ self.pump()
+ new_timings = self.successResultOf(d)
+ self.assertEqual(new_timings["failure_ts"], failure_ts)
+ self.assertEqual(new_timings["retry_last_ts"], failure_ts)
+ self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
+
+ # now if we try again we should get a failure
+ d = get_retry_limiter("test_dest", self.clock, store)
+ self.pump()
+ self.failureResultOf(d, NotRetryingDestination)
+
+ #
+ # advance the clock and try again
+ #
+
+ self.pump(MIN_RETRY_INTERVAL)
+ d = get_retry_limiter("test_dest", self.clock, store)
+ self.pump()
+ limiter = self.successResultOf(d)
+
+ self.pump(1)
+ try:
+ with limiter:
+ self.pump(1)
+ retry_ts = self.clock.time_msec()
+ raise AssertionError("argh")
+ except AssertionError:
+ pass
+
+ # wait for the update to land
+ self.pump()
+
+ d = store.get_destination_retry_timings("test_dest")
+ self.pump()
+ new_timings = self.successResultOf(d)
+ self.assertEqual(new_timings["failure_ts"], failure_ts)
+ self.assertEqual(new_timings["retry_last_ts"], retry_ts)
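+ # the new interval should be roughly MIN_RETRY_INTERVAL * RETRY_MULTIPLIER;
+ # allow a wide margin in case the backoff is jittered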
+ self.assertGreaterEqual(
+ new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
+ )
+ self.assertLessEqual(
+ new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
+ )
+
+ #
+ # one more go, with success
+ #
+ self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
+ d = get_retry_limiter("test_dest", self.clock, store)
+ self.pump()
+ limiter = self.successResultOf(d)
+
+ self.pump(1)
+ with limiter:
+ self.pump(1)
+
+ # wait for the update to land
+ self.pump()
+
+ d = store.get_destination_retry_timings("test_dest")
+ self.pump()
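+ # the successful request should have cleared the retry timings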
+ new_timings = self.successResultOf(d)
+ self.assertIsNone(new_timings)
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
deleted file mode 100644
index 1a44f72425..0000000000
--- a/tests/util/test_snapshot_cache.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet.defer import Deferred
-
-from synapse.util.caches.snapshot_cache import SnapshotCache
-
-from .. import unittest
-
-
-class SnapshotCacheTestCase(unittest.TestCase):
- def setUp(self):
- self.cache = SnapshotCache()
- self.cache.DURATION_MS = 1
-
- def test_get_set(self):
- # Check that getting a missing key returns None
- self.assertEquals(self.cache.get(0, "key"), None)
-
- # Check that setting a key with a deferred returns
- # a deferred that resolves when the initial deferred does
- d = Deferred()
- set_result = self.cache.set(0, "key", d)
- self.assertIsNotNone(set_result)
- self.assertFalse(set_result.called)
-
- # Check that getting the key before the deferred has resolved
- # returns a deferred that resolves when the initial deferred does.
- get_result_at_10 = self.cache.get(10, "key")
- self.assertIsNotNone(get_result_at_10)
- self.assertFalse(get_result_at_10.called)
-
- # Check that the returned deferreds resolve when the initial deferred
- # does.
- d.callback("v")
- self.assertTrue(set_result.called)
- self.assertTrue(get_result_at_10.called)
-
- # Check that getting the key after the deferred has resolved
- # before the cache expires returns a resolved deferred.
- get_result_at_11 = self.cache.get(11, "key")
- self.assertIsNotNone(get_result_at_11)
- if isinstance(get_result_at_11, Deferred):
- # The cache may return the actual result rather than a deferred
- self.assertTrue(get_result_at_11.called)
-
- # Check that getting the key after the deferred has resolved
- # after the cache expires returns None
- get_result_at_12 = self.cache.get(12, "key")
- self.assertIsNone(get_result_at_12)
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index f2be63706b..72a9de5370 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -67,7 +67,7 @@ class StreamChangeCacheTests(unittest.TestCase):
# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("bar@baz.net", 5)
self.assertEqual(
- set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key)
+ {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
)
def test_get_all_entities_changed(self):
@@ -137,7 +137,7 @@ class StreamChangeCacheTests(unittest.TestCase):
cache.get_entities_changed(
["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
),
- set(["bar@baz.net", "user@elsewhere.org"]),
+ {"bar@baz.net", "user@elsewhere.org"},
)
# Query all the entries mid-way through the stream, but include one
@@ -153,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase):
],
stream_pos=2,
),
- set(["bar@baz.net", "user@elsewhere.org"]),
+ {"bar@baz.net", "user@elsewhere.org"},
)
# Query all the entries, but before the first known point. We will get
@@ -168,21 +168,13 @@ class StreamChangeCacheTests(unittest.TestCase):
],
stream_pos=0,
),
- set(
- [
- "user@foo.com",
- "bar@baz.net",
- "user@elsewhere.org",
- "not@here.website",
- ]
- ),
+ {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"},
)
# Query a subset of the entries mid-way through the stream. We should
# only get back the subset.
self.assertEqual(
- cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
- set(["bar@baz.net"]),
+ cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"},
)
def test_max_pos(self):
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
new file mode 100644
index 0000000000..4f4da29a98
--- /dev/null
+++ b/tests/util/test_stringutils.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.errors import SynapseError
+from synapse.util.stringutils import assert_valid_client_secret
+
+from .. import unittest
+
+
+class StringUtilsTestCase(unittest.TestCase):
+ def test_client_secret_regex(self):
+ """Ensure that client_secret does not contain illegal characters"""
+ good = [
+ "abcde12345",
+ "ABCabc123",
+ "_--something==_",
+ "...--==-18913",
+ "8Dj2odd-e9asd.cd==_--ddas-secret-",
+ # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766
+ # To be removed in a future release
+ "SECRET:1234567890",
+ ]
+
+ bad = [
+ "--+-/secret",
+ "\\dx--dsa288",
+ "",
+ "AAS//",
+ "asdj**",
+ ">X><Z<!!-)))",
+ "a@b.com",
+ ]
+
+ for client_secret in good:
+ assert_valid_client_secret(client_secret)
+
+ for client_secret in bad:
+ with self.assertRaises(SynapseError):
+ assert_valid_client_secret(client_secret)