summary refs log tree commit diff
path: root/tests/util
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util')
-rw-r--r--tests/util/caches/test_descriptors.py24
-rw-r--r--tests/util/test_dict_cache.py16
-rw-r--r--tests/util/test_expiring_cache.py4
-rw-r--r--tests/util/test_file_consumer.py15
-rw-r--r--tests/util/test_limiter.py4
-rw-r--r--tests/util/test_linearizer.py12
-rw-r--r--tests/util/test_logcontext.py17
-rw-r--r--tests/util/test_logformatter.py1
-rw-r--r--tests/util/test_lrucache.py4
-rw-r--r--tests/util/test_rwlock.py4
-rw-r--r--tests/util/test_snapshot_cache.py5
-rw-r--r--tests/util/test_stream_change_cache.py199
-rw-r--r--tests/util/test_treecache.py4
-rw-r--r--tests/util/test_wheel_timer.py4
14 files changed, 266 insertions, 47 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2516fe40f4..8176a7dabd 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,20 +13,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
 import logging
+from functools import partial
 
 import mock
+
+from twisted.internet import defer, reactor
+
 from synapse.api.errors import SynapseError
-from synapse.util import async
 from synapse.util import logcontext
-from twisted.internet import defer
 from synapse.util.caches import descriptors
+
 from tests import unittest
 
 logger = logging.getLogger(__name__)
 
 
+def run_on_reactor():
+    d = defer.Deferred()
+    reactor.callLater(0, d.callback, 0)
+    return logcontext.make_deferred_yieldable(d)
+
+
 class CacheTestCase(unittest.TestCase):
     def test_invalidate_all(self):
         cache = descriptors.Cache("testcache")
@@ -195,7 +203,8 @@ class DescriptorTestCase(unittest.TestCase):
             def fn(self, arg1):
                 @defer.inlineCallbacks
                 def inner_fn():
-                    yield async.run_on_reactor()
+                    # we want this to behave like an asynchronous function
+                    yield run_on_reactor()
                     raise SynapseError(400, "blah")
 
                 return inner_fn()
@@ -205,7 +214,12 @@ class DescriptorTestCase(unittest.TestCase):
             with logcontext.LoggingContext() as c1:
                 c1.name = "c1"
                 try:
-                    yield obj.fn(1)
+                    d = obj.fn(1)
+                    self.assertEqual(
+                        logcontext.LoggingContext.current_context(),
+                        logcontext.LoggingContext.sentinel,
+                    )
+                    yield d
                     self.fail("No exception thrown")
                 except SynapseError:
                     pass
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index bc92f85fa6..26f2fa5800 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -14,10 +14,10 @@
 # limitations under the License.
 
 
-from tests import unittest
-
 from synapse.util.caches.dictionary_cache import DictionaryCache
 
+from tests import unittest
+
 
 class DictCacheTestCase(unittest.TestCase):
 
@@ -32,7 +32,7 @@ class DictCacheTestCase(unittest.TestCase):
 
         seq = self.cache.sequence
         test_value = {"test": "test_simple_cache_hit_full"}
-        self.cache.update(seq, key, test_value, full=True)
+        self.cache.update(seq, key, test_value)
 
         c = self.cache.get(key)
         self.assertEqual(test_value, c.value)
@@ -44,7 +44,7 @@ class DictCacheTestCase(unittest.TestCase):
         test_value = {
             "test": "test_simple_cache_hit_partial"
         }
-        self.cache.update(seq, key, test_value, full=True)
+        self.cache.update(seq, key, test_value)
 
         c = self.cache.get(key, ["test"])
         self.assertEqual(test_value, c.value)
@@ -56,7 +56,7 @@ class DictCacheTestCase(unittest.TestCase):
         test_value = {
             "test": "test_simple_cache_miss_partial"
         }
-        self.cache.update(seq, key, test_value, full=True)
+        self.cache.update(seq, key, test_value)
 
         c = self.cache.get(key, ["test2"])
         self.assertEqual({}, c.value)
@@ -70,7 +70,7 @@ class DictCacheTestCase(unittest.TestCase):
             "test2": "test_simple_cache_hit_miss_partial2",
             "test3": "test_simple_cache_hit_miss_partial3",
         }
-        self.cache.update(seq, key, test_value, full=True)
+        self.cache.update(seq, key, test_value)
 
         c = self.cache.get(key, ["test2"])
         self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
@@ -82,13 +82,13 @@ class DictCacheTestCase(unittest.TestCase):
         test_value_1 = {
             "test": "test_simple_cache_hit_miss_partial",
         }
-        self.cache.update(seq, key, test_value_1, full=False)
+        self.cache.update(seq, key, test_value_1, fetched_keys=set("test"))
 
         seq = self.cache.sequence
         test_value_2 = {
             "test2": "test_simple_cache_hit_miss_partial2",
         }
-        self.cache.update(seq, key, test_value_2, full=False)
+        self.cache.update(seq, key, test_value_2, fetched_keys=set("test2"))
 
         c = self.cache.get(key)
         self.assertEqual(
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index 31d24adb8b..d12b5e838b 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -14,12 +14,12 @@
 # limitations under the License.
 
 
-from .. import unittest
-
 from synapse.util.caches.expiringcache import ExpiringCache
 
 from tests.utils import MockClock
 
+from .. import unittest
+
 
 class ExpiringCacheTestCase(unittest.TestCase):
 
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index d6e1082779..7ce5f8c258 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -14,15 +14,16 @@
 # limitations under the License.
 
 
-from twisted.internet import defer, reactor
+import threading
+
 from mock import NonCallableMock
+from six import StringIO
+
+from twisted.internet import defer, reactor
 
 from synapse.util.file_consumer import BackgroundFileConsumer
 
 from tests import unittest
-from six import StringIO
-
-import threading
 
 
 class FileConsumerTests(unittest.TestCase):
@@ -30,7 +31,7 @@ class FileConsumerTests(unittest.TestCase):
     @defer.inlineCallbacks
     def test_pull_consumer(self):
         string_file = StringIO()
-        consumer = BackgroundFileConsumer(string_file)
+        consumer = BackgroundFileConsumer(string_file, reactor=reactor)
 
         try:
             producer = DummyPullProducer()
@@ -54,7 +55,7 @@ class FileConsumerTests(unittest.TestCase):
     @defer.inlineCallbacks
     def test_push_consumer(self):
         string_file = BlockingStringWrite()
-        consumer = BackgroundFileConsumer(string_file)
+        consumer = BackgroundFileConsumer(string_file, reactor=reactor)
 
         try:
             producer = NonCallableMock(spec_set=[])
@@ -80,7 +81,7 @@ class FileConsumerTests(unittest.TestCase):
     @defer.inlineCallbacks
     def test_push_producer_feedback(self):
         string_file = BlockingStringWrite()
-        consumer = BackgroundFileConsumer(string_file)
+        consumer = BackgroundFileConsumer(string_file, reactor=reactor)
 
         try:
             producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py
index 9c795d9fdb..a5a767b1ff 100644
--- a/tests/util/test_limiter.py
+++ b/tests/util/test_limiter.py
@@ -14,12 +14,12 @@
 # limitations under the License.
 
 
-from tests import unittest
-
 from twisted.internet import defer
 
 from synapse.util.async import Limiter
 
+from tests import unittest
+
 
 class LimiterTestCase(unittest.TestCase):
 
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 4865eb4bc6..c95907b32c 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -12,13 +12,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.util import async, logcontext
-from tests import unittest
 
-from twisted.internet import defer
+from six.moves import range
+
+from twisted.internet import defer, reactor
 
+from synapse.util import Clock, logcontext
 from synapse.util.async import Linearizer
-from six.moves import range
+
+from tests import unittest
 
 
 class LinearizerTestCase(unittest.TestCase):
@@ -53,7 +55,7 @@ class LinearizerTestCase(unittest.TestCase):
                     self.assertEqual(
                         logcontext.LoggingContext.current_context(), lc)
                     if sleep:
-                        yield async.sleep(0)
+                        yield Clock(reactor).sleep(0)
 
                 self.assertEqual(
                     logcontext.LoggingContext.current_context(), lc)
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index ad78d884e0..c54001f7a4 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -1,12 +1,11 @@
 import twisted.python.failure
-from twisted.internet import defer
-from twisted.internet import reactor
-from .. import unittest
+from twisted.internet import defer, reactor
 
-from synapse.util.async import sleep
-from synapse.util import logcontext
+from synapse.util import Clock, logcontext
 from synapse.util.logcontext import LoggingContext
 
+from .. import unittest
+
 
 class LoggingContextTestCase(unittest.TestCase):
 
@@ -22,18 +21,20 @@ class LoggingContextTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_sleep(self):
+        clock = Clock(reactor)
+
         @defer.inlineCallbacks
         def competing_callback():
             with LoggingContext() as competing_context:
                 competing_context.request = "competing"
-                yield sleep(0)
+                yield clock.sleep(0)
                 self._check_test_key("competing")
 
         reactor.callLater(0, competing_callback)
 
         with LoggingContext() as context_one:
             context_one.request = "one"
-            yield sleep(0)
+            yield clock.sleep(0)
             self._check_test_key("one")
 
     def _test_run_in_background(self, function):
@@ -87,7 +88,7 @@ class LoggingContextTestCase(unittest.TestCase):
     def test_run_in_background_with_blocking_fn(self):
         @defer.inlineCallbacks
         def blocking_function():
-            yield sleep(0)
+            yield Clock(reactor).sleep(0)
 
         return self._test_run_in_background(blocking_function)
 
diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py
index 1a1a8412f2..297aebbfbe 100644
--- a/tests/util/test_logformatter.py
+++ b/tests/util/test_logformatter.py
@@ -15,6 +15,7 @@
 import sys
 
 from synapse.util.logformatter import LogFormatter
+
 from tests import unittest
 
 
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index dfb78cb8bd..9b36ef4482 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -14,12 +14,12 @@
 # limitations under the License.
 
 
-from .. import unittest
+from mock import Mock
 
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache
 
-from mock import Mock
+from .. import unittest
 
 
 class LruCacheTestCase(unittest.TestCase):
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index 1d745ae1a7..24194e3b25 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -14,10 +14,10 @@
 # limitations under the License.
 
 
-from tests import unittest
-
 from synapse.util.async import ReadWriteLock
 
+from tests import unittest
+
 
 class ReadWriteLockTestCase(unittest.TestCase):
 
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
index d3a8630c2f..0f5b32fcc0 100644
--- a/tests/util/test_snapshot_cache.py
+++ b/tests/util/test_snapshot_cache.py
@@ -14,10 +14,11 @@
 # limitations under the License.
 
 
-from .. import unittest
+from twisted.internet.defer import Deferred
 
 from synapse.util.caches.snapshot_cache import SnapshotCache
-from twisted.internet.defer import Deferred
+
+from .. import unittest
 
 
 class SnapshotCacheTestCase(unittest.TestCase):
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
new file mode 100644
index 0000000000..e3897c0d19
--- /dev/null
+++ b/tests/util/test_stream_change_cache.py
@@ -0,0 +1,199 @@
+from mock import patch
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from tests import unittest
+
+
+class StreamChangeCacheTests(unittest.TestCase):
+    """
+    Tests for StreamChangeCache.
+    """
+
+    def test_prefilled_cache(self):
+        """
+        Providing a prefilled cache to StreamChangeCache will result in a cache
+        with the prefilled-cache entered in.
+        """
+        cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
+        self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
+
+    def test_has_entity_changed(self):
+        """
+        StreamChangeCache.entity_has_changed will mark entities as changed, and
+        has_entity_changed will observe the changed entities.
+        """
+        cache = StreamChangeCache("#test", 3)
+
+        cache.entity_has_changed("user@foo.com", 6)
+        cache.entity_has_changed("bar@baz.net", 7)
+
+        # If it's been changed after that stream position, return True
+        self.assertTrue(cache.has_entity_changed("user@foo.com", 4))
+        self.assertTrue(cache.has_entity_changed("bar@baz.net", 4))
+
+        # If it's been changed at that stream position, return False
+        self.assertFalse(cache.has_entity_changed("user@foo.com", 6))
+
+        # If there's no changes after that stream position, return False
+        self.assertFalse(cache.has_entity_changed("user@foo.com", 7))
+
+        # If the entity does not exist, return False.
+        self.assertFalse(cache.has_entity_changed("not@here.website", 7))
+
+        # If we request before the stream cache's earliest known position,
+        # return True, whether it's a known entity or not.
+        self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
+        self.assertTrue(cache.has_entity_changed("not@here.website", 0))
+
+    @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0)
+    def test_has_entity_changed_pops_off_start(self):
+        """
+        StreamChangeCache.entity_has_changed will respect the max size and
+        purge the oldest items upon reaching that max size.
+        """
+        cache = StreamChangeCache("#test", 1, max_size=2)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        # The cache is at the max size, 2
+        self.assertEqual(len(cache._cache), 2)
+
+        # The oldest item has been popped off
+        self.assertTrue("user@foo.com" not in cache._entity_to_key)
+
+        # If we update an existing entity, it keeps the two existing entities
+        cache.entity_has_changed("bar@baz.net", 5)
+        self.assertEqual(
+            set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key)
+        )
+
+    def test_get_all_entities_changed(self):
+        """
+        StreamChangeCache.get_all_entities_changed will return all changed
+        entities since the given position.  If the position is before the start
+        of the known stream, it returns None instead.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        self.assertEqual(
+            cache.get_all_entities_changed(1),
+            ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
+        )
+        self.assertEqual(
+            cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"]
+        )
+        self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
+        self.assertEqual(cache.get_all_entities_changed(0), None)
+
+    def test_has_any_entity_changed(self):
+        """
+        StreamChangeCache.has_any_entity_changed will return True if any
+        entities have been changed since the provided stream position, and
+        False if they have not.  If the cache has entries and the provided
+        stream position is before it, it will return True, otherwise False if
+        the cache has no entries.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        # With no entities, it returns False for the past, present, and future.
+        self.assertFalse(cache.has_any_entity_changed(0))
+        self.assertFalse(cache.has_any_entity_changed(1))
+        self.assertFalse(cache.has_any_entity_changed(2))
+
+        # We add an entity
+        cache.entity_has_changed("user@foo.com", 2)
+
+        # With an entity, it returns True for the past, the stream start
+        # position, and False for the stream position the entity was changed
+        # on and ones after it.
+        self.assertTrue(cache.has_any_entity_changed(0))
+        self.assertTrue(cache.has_any_entity_changed(1))
+        self.assertFalse(cache.has_any_entity_changed(2))
+        self.assertFalse(cache.has_any_entity_changed(3))
+
+    def test_get_entities_changed(self):
+        """
+        StreamChangeCache.get_entities_changed will return the entities in the
+        given list that have changed since the provided stream ID.  If the
+        stream position is earlier than the earliest known position, it will
+        return all of the entities queried for.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        # Query all the entries, but mid-way through the stream. We should only
+        # get the ones after that point.
+        self.assertEqual(
+            cache.get_entities_changed(
+                ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
+            ),
+            set(["bar@baz.net", "user@elsewhere.org"]),
+        )
+
+        # Query all the entries mid-way through the stream, but include one
+        # that doesn't exist in it. We should get back the one that doesn't
+        # exist, too.
+        self.assertEqual(
+            cache.get_entities_changed(
+                [
+                    "user@foo.com",
+                    "bar@baz.net",
+                    "user@elsewhere.org",
+                    "not@here.website",
+                ],
+                stream_pos=2,
+            ),
+            set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]),
+        )
+
+        # Query all the entries, but before the first known point. We will get
+        # all the entries we queried for, including ones that don't exist.
+        self.assertEqual(
+            cache.get_entities_changed(
+                [
+                    "user@foo.com",
+                    "bar@baz.net",
+                    "user@elsewhere.org",
+                    "not@here.website",
+                ],
+                stream_pos=0,
+            ),
+            set(
+                [
+                    "user@foo.com",
+                    "bar@baz.net",
+                    "user@elsewhere.org",
+                    "not@here.website",
+                ]
+            ),
+        )
+
+    def test_max_pos(self):
+        """
+        StreamChangeCache.get_max_pos_of_last_change will return the most
+        recent point where the entity could have changed.  If the entity is not
+        known, the stream start is provided instead.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        # Known entities will return the point where they were changed.
+        self.assertEqual(cache.get_max_pos_of_last_change("user@foo.com"), 2)
+        self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3)
+        self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4)
+
+        # Unknown entities will return the stream start position.
+        self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1)
diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 7ab578a185..a5f2261208 100644
--- a/tests/util/test_treecache.py
+++ b/tests/util/test_treecache.py
@@ -14,10 +14,10 @@
 # limitations under the License.
 
 
-from .. import unittest
-
 from synapse.util.caches.treecache import TreeCache
 
+from .. import unittest
+
 
 class TreeCacheTestCase(unittest.TestCase):
     def test_get_set_onelevel(self):
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index fdb24a48b0..03201a4d9b 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .. import unittest
-
 from synapse.util.wheel_timer import WheelTimer
 
+from .. import unittest
+
 
 class WheelTimerTestCase(unittest.TestCase):
     def test_single_insert_fetch(self):