diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2516fe40f4..8176a7dabd 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,20 +13,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from functools import partial
import logging
+from functools import partial
import mock
+
+from twisted.internet import defer, reactor
+
from synapse.api.errors import SynapseError
-from synapse.util import async
from synapse.util import logcontext
-from twisted.internet import defer
from synapse.util.caches import descriptors
+
from tests import unittest
logger = logging.getLogger(__name__)
+def run_on_reactor():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, 0)
+ return logcontext.make_deferred_yieldable(d)
+
+
class CacheTestCase(unittest.TestCase):
def test_invalidate_all(self):
cache = descriptors.Cache("testcache")
@@ -195,7 +203,8 @@ class DescriptorTestCase(unittest.TestCase):
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
- yield async.run_on_reactor()
+ # we want this to behave like an asynchronous function
+ yield run_on_reactor()
raise SynapseError(400, "blah")
return inner_fn()
@@ -205,7 +214,12 @@ class DescriptorTestCase(unittest.TestCase):
with logcontext.LoggingContext() as c1:
c1.name = "c1"
try:
- yield obj.fn(1)
+ d = obj.fn(1)
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ logcontext.LoggingContext.sentinel,
+ )
+ yield d
self.fail("No exception thrown")
except SynapseError:
pass
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index bc92f85fa6..26f2fa5800 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -14,10 +14,10 @@
# limitations under the License.
-from tests import unittest
-
from synapse.util.caches.dictionary_cache import DictionaryCache
+from tests import unittest
+
class DictCacheTestCase(unittest.TestCase):
@@ -32,7 +32,7 @@ class DictCacheTestCase(unittest.TestCase):
seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key)
self.assertEqual(test_value, c.value)
@@ -44,7 +44,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = {
"test": "test_simple_cache_hit_partial"
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value)
@@ -56,7 +56,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = {
"test": "test_simple_cache_miss_partial"
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value)
@@ -70,7 +70,7 @@ class DictCacheTestCase(unittest.TestCase):
"test2": "test_simple_cache_hit_miss_partial2",
"test3": "test_simple_cache_hit_miss_partial3",
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
@@ -82,13 +82,13 @@ class DictCacheTestCase(unittest.TestCase):
test_value_1 = {
"test": "test_simple_cache_hit_miss_partial",
}
- self.cache.update(seq, key, test_value_1, full=False)
+ self.cache.update(seq, key, test_value_1, fetched_keys=set("test"))
seq = self.cache.sequence
test_value_2 = {
"test2": "test_simple_cache_hit_miss_partial2",
}
- self.cache.update(seq, key, test_value_2, full=False)
+ self.cache.update(seq, key, test_value_2, fetched_keys=set("test2"))
c = self.cache.get(key)
self.assertEqual(
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index 31d24adb8b..d12b5e838b 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from .. import unittest
-
from synapse.util.caches.expiringcache import ExpiringCache
from tests.utils import MockClock
+from .. import unittest
+
class ExpiringCacheTestCase(unittest.TestCase):
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index d6e1082779..7ce5f8c258 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -14,15 +14,16 @@
# limitations under the License.
-from twisted.internet import defer, reactor
+import threading
+
from mock import NonCallableMock
+from six import StringIO
+
+from twisted.internet import defer, reactor
from synapse.util.file_consumer import BackgroundFileConsumer
from tests import unittest
-from six import StringIO
-
-import threading
class FileConsumerTests(unittest.TestCase):
@@ -30,7 +31,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_pull_consumer(self):
string_file = StringIO()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = DummyPullProducer()
@@ -54,7 +55,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_push_consumer(self):
string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=[])
@@ -80,7 +81,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_push_producer_feedback(self):
string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py
index 9c795d9fdb..a5a767b1ff 100644
--- a/tests/util/test_limiter.py
+++ b/tests/util/test_limiter.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from tests import unittest
-
from twisted.internet import defer
from synapse.util.async import Limiter
+from tests import unittest
+
class LimiterTestCase(unittest.TestCase):
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 4865eb4bc6..c95907b32c 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -12,13 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util import async, logcontext
-from tests import unittest
-from twisted.internet import defer
+from six.moves import range
+
+from twisted.internet import defer, reactor
+from synapse.util import Clock, logcontext
from synapse.util.async import Linearizer
-from six.moves import range
+
+from tests import unittest
class LinearizerTestCase(unittest.TestCase):
@@ -53,7 +55,7 @@ class LinearizerTestCase(unittest.TestCase):
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)
if sleep:
- yield async.sleep(0)
+ yield Clock(reactor).sleep(0)
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index ad78d884e0..c54001f7a4 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -1,12 +1,11 @@
import twisted.python.failure
-from twisted.internet import defer
-from twisted.internet import reactor
-from .. import unittest
+from twisted.internet import defer, reactor
-from synapse.util.async import sleep
-from synapse.util import logcontext
+from synapse.util import Clock, logcontext
from synapse.util.logcontext import LoggingContext
+from .. import unittest
+
class LoggingContextTestCase(unittest.TestCase):
@@ -22,18 +21,20 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_sleep(self):
+ clock = Clock(reactor)
+
@defer.inlineCallbacks
def competing_callback():
with LoggingContext() as competing_context:
competing_context.request = "competing"
- yield sleep(0)
+ yield clock.sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
with LoggingContext() as context_one:
context_one.request = "one"
- yield sleep(0)
+ yield clock.sleep(0)
self._check_test_key("one")
def _test_run_in_background(self, function):
@@ -87,7 +88,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_run_in_background_with_blocking_fn(self):
@defer.inlineCallbacks
def blocking_function():
- yield sleep(0)
+ yield Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function)
diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py
index 1a1a8412f2..297aebbfbe 100644
--- a/tests/util/test_logformatter.py
+++ b/tests/util/test_logformatter.py
@@ -15,6 +15,7 @@
import sys
from synapse.util.logformatter import LogFormatter
+
from tests import unittest
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index dfb78cb8bd..9b36ef4482 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from .. import unittest
+from mock import Mock
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
-from mock import Mock
+from .. import unittest
class LruCacheTestCase(unittest.TestCase):
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index 1d745ae1a7..24194e3b25 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -14,10 +14,10 @@
# limitations under the License.
-from tests import unittest
-
from synapse.util.async import ReadWriteLock
+from tests import unittest
+
class ReadWriteLockTestCase(unittest.TestCase):
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
index d3a8630c2f..0f5b32fcc0 100644
--- a/tests/util/test_snapshot_cache.py
+++ b/tests/util/test_snapshot_cache.py
@@ -14,10 +14,11 @@
# limitations under the License.
-from .. import unittest
+from twisted.internet.defer import Deferred
from synapse.util.caches.snapshot_cache import SnapshotCache
-from twisted.internet.defer import Deferred
+
+from .. import unittest
class SnapshotCacheTestCase(unittest.TestCase):
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
new file mode 100644
index 0000000000..e3897c0d19
--- /dev/null
+++ b/tests/util/test_stream_change_cache.py
@@ -0,0 +1,199 @@
+from mock import patch
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from tests import unittest
+
+
+class StreamChangeCacheTests(unittest.TestCase):
+ """
+ Tests for StreamChangeCache.
+ """
+
+ def test_prefilled_cache(self):
+ """
+ Providing a prefilled cache to StreamChangeCache will result in a cache
+ with the prefilled-cache entered in.
+ """
+ cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
+
+ def test_has_entity_changed(self):
+ """
+ StreamChangeCache.entity_has_changed will mark entities as changed, and
+ has_entity_changed will observe the changed entities.
+ """
+ cache = StreamChangeCache("#test", 3)
+
+ cache.entity_has_changed("user@foo.com", 6)
+ cache.entity_has_changed("bar@baz.net", 7)
+
+ # If it's been changed after that stream position, return True
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 4))
+ self.assertTrue(cache.has_entity_changed("bar@baz.net", 4))
+
+ # If it's been changed at that stream position, return False
+ self.assertFalse(cache.has_entity_changed("user@foo.com", 6))
+
+ # If there's no changes after that stream position, return False
+ self.assertFalse(cache.has_entity_changed("user@foo.com", 7))
+
+ # If the entity does not exist, return False.
+ self.assertFalse(cache.has_entity_changed("not@here.website", 7))
+
+ # If we request before the stream cache's earliest known position,
+ # return True, whether it's a known entity or not.
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
+ self.assertTrue(cache.has_entity_changed("not@here.website", 0))
+
+ @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0)
+ def test_has_entity_changed_pops_off_start(self):
+ """
+ StreamChangeCache.entity_has_changed will respect the max size and
+ purge the oldest items upon reaching that max size.
+ """
+ cache = StreamChangeCache("#test", 1, max_size=2)
+
+ cache.entity_has_changed("user@foo.com", 2)
+ cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("user@elsewhere.org", 4)
+
+ # The cache is at the max size, 2
+ self.assertEqual(len(cache._cache), 2)
+
+ # The oldest item has been popped off
+ self.assertTrue("user@foo.com" not in cache._entity_to_key)
+
+ # If we update an existing entity, it keeps the two existing entities
+ cache.entity_has_changed("bar@baz.net", 5)
+ self.assertEqual(
+ set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key)
+ )
+
+ def test_get_all_entities_changed(self):
+ """
+ StreamChangeCache.get_all_entities_changed will return all changed
+ entities since the given position. If the position is before the start
+ of the known stream, it returns None instead.
+ """
+ cache = StreamChangeCache("#test", 1)
+
+ cache.entity_has_changed("user@foo.com", 2)
+ cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("user@elsewhere.org", 4)
+
+ self.assertEqual(
+ cache.get_all_entities_changed(1),
+ ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
+ )
+ self.assertEqual(
+ cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"]
+ )
+ self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
+ self.assertEqual(cache.get_all_entities_changed(0), None)
+
+ def test_has_any_entity_changed(self):
+ """
+ StreamChangeCache.has_any_entity_changed will return True if any
+ entities have been changed since the provided stream position, and
+ False if they have not. If the cache has entries and the provided
+ stream position is before it, it will return True, otherwise False if
+ the cache has no entries.
+ """
+ cache = StreamChangeCache("#test", 1)
+
+ # With no entities, it returns False for the past, present, and future.
+ self.assertFalse(cache.has_any_entity_changed(0))
+ self.assertFalse(cache.has_any_entity_changed(1))
+ self.assertFalse(cache.has_any_entity_changed(2))
+
+ # We add an entity
+ cache.entity_has_changed("user@foo.com", 2)
+
+ # With an entity, it returns True for the past, the stream start
+ # position, and False for the stream position the entity was changed
+ # on and ones after it.
+ self.assertTrue(cache.has_any_entity_changed(0))
+ self.assertTrue(cache.has_any_entity_changed(1))
+ self.assertFalse(cache.has_any_entity_changed(2))
+ self.assertFalse(cache.has_any_entity_changed(3))
+
+ def test_get_entities_changed(self):
+ """
+ StreamChangeCache.get_entities_changed will return the entities in the
+ given list that have changed since the provided stream ID. If the
+ stream position is earlier than the earliest known position, it will
+ return all of the entities queried for.
+ """
+ cache = StreamChangeCache("#test", 1)
+
+ cache.entity_has_changed("user@foo.com", 2)
+ cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("user@elsewhere.org", 4)
+
+ # Query all the entries, but mid-way through the stream. We should only
+ # get the ones after that point.
+ self.assertEqual(
+ cache.get_entities_changed(
+ ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
+ ),
+ set(["bar@baz.net", "user@elsewhere.org"]),
+ )
+
+ # Query all the entries mid-way through the stream, but include one
+ # that doesn't exist in it. We should get back the one that doesn't
+ # exist, too.
+ self.assertEqual(
+ cache.get_entities_changed(
+ [
+ "user@foo.com",
+ "bar@baz.net",
+ "user@elsewhere.org",
+ "not@here.website",
+ ],
+ stream_pos=2,
+ ),
+ set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]),
+ )
+
+ # Query all the entries, but before the first known point. We will get
+ # all the entries we queried for, including ones that don't exist.
+ self.assertEqual(
+ cache.get_entities_changed(
+ [
+ "user@foo.com",
+ "bar@baz.net",
+ "user@elsewhere.org",
+ "not@here.website",
+ ],
+ stream_pos=0,
+ ),
+ set(
+ [
+ "user@foo.com",
+ "bar@baz.net",
+ "user@elsewhere.org",
+ "not@here.website",
+ ]
+ ),
+ )
+
+ def test_max_pos(self):
+ """
+ StreamChangeCache.get_max_pos_of_last_change will return the most
+ recent point where the entity could have changed. If the entity is not
+ known, the stream start is provided instead.
+ """
+ cache = StreamChangeCache("#test", 1)
+
+ cache.entity_has_changed("user@foo.com", 2)
+ cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("user@elsewhere.org", 4)
+
+ # Known entities will return the point where they were changed.
+ self.assertEqual(cache.get_max_pos_of_last_change("user@foo.com"), 2)
+ self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3)
+ self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4)
+
+ # Unknown entities will return the stream start position.
+ self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1)
diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 7ab578a185..a5f2261208 100644
--- a/tests/util/test_treecache.py
+++ b/tests/util/test_treecache.py
@@ -14,10 +14,10 @@
# limitations under the License.
-from .. import unittest
-
from synapse.util.caches.treecache import TreeCache
+from .. import unittest
+
class TreeCacheTestCase(unittest.TestCase):
def test_get_set_onelevel(self):
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index fdb24a48b0..03201a4d9b 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .. import unittest
-
from synapse.util.wheel_timer import WheelTimer
+from .. import unittest
+
class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self):
|