diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 570312da84..c899fecf5d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase):
def check_context(self, _, expected):
self.assertEquals(
- getattr(LoggingContext.current_context(), "test_key", None),
+ getattr(LoggingContext.current_context(), "request", None),
expected
)
@@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase):
lookup_2_deferred = defer.Deferred()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
wait_1_deferred = kr.wait_for_previous_lookups(
["server1"],
@@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase):
wait_1_deferred.addBoth(self.check_context, "one")
with LoggingContext("two") as context_two:
- context_two.test_key = "two"
+ context_two.request = "two"
# set off another wait. It should block because the first lookup
# hasn't yet completed.
@@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
self.assertEquals(
- LoggingContext.current_context().test_key, "11",
+ LoggingContext.current_context().request, "11",
)
with logcontext.PreserveLoggingContext():
yield persp_deferred
@@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase):
self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11:
- context_11.test_key = "11"
+ context_11.request = "11"
# start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server(
@@ -173,7 +173,7 @@ class KeyringTestCase(unittest.TestCase):
self.assertIs(LoggingContext.current_context(), context_11)
context_12 = LoggingContext("12")
- context_12.test_key = "12"
+ context_12.request = "12"
with logcontext.PreserveLoggingContext(context_12):
# a second request for a server with outstanding requests
# should block rather than start a second call
@@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
defer = kr.verify_json_for_server("server9", {})
try:
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
new file mode 100644
index 0000000000..76e2234255
--- /dev/null
+++ b/tests/util/test_file_consumer.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer, reactor
+from mock import NonCallableMock
+
+from synapse.util.file_consumer import BackgroundFileConsumer
+
+from tests import unittest
+from StringIO import StringIO
+
+import threading
+
+
+class FileConsumerTests(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def test_pull_consumer(self):
+ string_file = StringIO()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = DummyPullProducer()
+
+ yield producer.register_with_consumer(consumer)
+
+ yield producer.write_and_wait("Foo")
+
+ self.assertEqual(string_file.getvalue(), "Foo")
+
+ yield producer.write_and_wait("Bar")
+
+ self.assertEqual(string_file.getvalue(), "FooBar")
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+ @defer.inlineCallbacks
+ def test_push_consumer(self):
+ string_file = BlockingStringWrite()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = NonCallableMock(spec_set=[])
+
+ consumer.registerProducer(producer, True)
+
+ consumer.write("Foo")
+ yield string_file.wait_for_n_writes(1)
+
+ self.assertEqual(string_file.buffer, "Foo")
+
+ consumer.write("Bar")
+ yield string_file.wait_for_n_writes(2)
+
+ self.assertEqual(string_file.buffer, "FooBar")
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+ @defer.inlineCallbacks
+ def test_push_producer_feedback(self):
+ string_file = BlockingStringWrite()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
+
+ resume_deferred = defer.Deferred()
+ producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None)
+
+ consumer.registerProducer(producer, True)
+
+ number_writes = 0
+ with string_file.write_lock:
+ for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
+ consumer.write("Foo")
+ number_writes += 1
+
+ producer.pauseProducing.assert_called_once()
+
+ yield string_file.wait_for_n_writes(number_writes)
+
+ yield resume_deferred
+ producer.resumeProducing.assert_called_once()
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+
+class DummyPullProducer(object):
+ def __init__(self):
+ self.consumer = None
+ self.deferred = defer.Deferred()
+
+ def resumeProducing(self):
+ d = self.deferred
+ self.deferred = defer.Deferred()
+ d.callback(None)
+
+ def write_and_wait(self, bytes):
+ d = self.deferred
+ self.consumer.write(bytes)
+ return d
+
+ def register_with_consumer(self, consumer):
+ d = self.deferred
+ self.consumer = consumer
+ self.consumer.registerProducer(self, False)
+ return d
+
+
+class BlockingStringWrite(object):
+ def __init__(self):
+ self.buffer = ""
+ self.closed = False
+ self.write_lock = threading.Lock()
+
+ self._notify_write_deferred = None
+ self._number_of_writes = 0
+
+ def write(self, bytes):
+ with self.write_lock:
+ self.buffer += bytes
+ self._number_of_writes += 1
+
+ reactor.callFromThread(self._notify_write)
+
+ def close(self):
+ self.closed = True
+
+ def _notify_write(self):
+ "Called by write to indicate a write happened"
+ with self.write_lock:
+ if not self._notify_write_deferred:
+ return
+ d = self._notify_write_deferred
+ self._notify_write_deferred = None
+ d.callback(None)
+
+ @defer.inlineCallbacks
+ def wait_for_n_writes(self, n):
+ "Wait for n writes to have happened"
+ while True:
+ with self.write_lock:
+ if n <= self._number_of_writes:
+ return
+
+ if not self._notify_write_deferred:
+ self._notify_write_deferred = defer.Deferred()
+
+ d = self._notify_write_deferred
+
+ yield d
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index e2f7765f49..4850722bc5 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
self.assertEquals(
- LoggingContext.current_context().test_key, value
+ LoggingContext.current_context().request, value
)
def test_with_context(self):
with LoggingContext() as context_one:
- context_one.test_key = "test"
+ context_one.request = "test"
self._check_test_key("test")
@defer.inlineCallbacks
@@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def competing_callback():
with LoggingContext() as competing_context:
- competing_context.test_key = "competing"
+ competing_context.request = "competing"
yield sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
yield sleep(0)
self._check_test_key("one")
@@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def cb():
- context_one.test_key = "one"
+ context_one.request = "one"
yield function()
self._check_test_key("one")
callback_completed[0] = True
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
# fire off function, but don't wait on it.
logcontext.preserve_fn(cb)()
@@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
@@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
argument isn't actually a deferred"""
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable("bum")
self._check_test_key("one")
|