diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
new file mode 100644
index 0000000000..dadfabd46d
--- /dev/null
+++ b/tests/util/caches/test_deferred_cache.py
@@ -0,0 +1,251 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+from twisted.internet import defer
+
+from synapse.util.caches.deferred_cache import DeferredCache
+
+from tests.unittest import TestCase
+
+
+class DeferredCacheTestCase(TestCase):
+ def test_empty(self):
+ cache = DeferredCache("test")
+ failed = False
+ try:
+ cache.get("foo")
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ def test_hit(self):
+ cache = DeferredCache("test")
+ cache.prefill("foo", 123)
+
+ self.assertEquals(self.successResultOf(cache.get("foo")), 123)
+
+ def test_hit_deferred(self):
+ cache = DeferredCache("test")
+ origin_d = defer.Deferred()
+ set_d = cache.set("k1", origin_d)
+
+ # get should return an incomplete deferred
+ get_d = cache.get("k1")
+ self.assertFalse(get_d.called)
+
+ # add a callback that will make sure that the set_d gets called before the get_d
+ def check1(r):
+ self.assertTrue(set_d.called)
+ return r
+
+        # TODO: Actually ObservableDeferred *doesn't* run its callbacks in order on py3.8.
+ # maybe we should fix that?
+ # get_d.addCallback(check1)
+
+ # now fire off all the deferreds
+ origin_d.callback(99)
+ self.assertEqual(self.successResultOf(origin_d), 99)
+ self.assertEqual(self.successResultOf(set_d), 99)
+ self.assertEqual(self.successResultOf(get_d), 99)
+
+ def test_callbacks(self):
+ """Invalidation callbacks are called at the right time"""
+ cache = DeferredCache("test")
+ callbacks = set()
+
+ # start with an entry, with a callback
+ cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+ # now replace that entry with a pending result
+ origin_d = defer.Deferred()
+ set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+ # ... and also make a get request
+ get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+ # we don't expect the invalidation callback for the original value to have
+ # been called yet, even though get() will now return a different result.
+ # I'm not sure if that is by design or not.
+ self.assertEqual(callbacks, set())
+
+ # now fire off all the deferreds
+ origin_d.callback(20)
+ self.assertEqual(self.successResultOf(set_d), 20)
+ self.assertEqual(self.successResultOf(get_d), 20)
+
+ # now the original invalidation callback should have been called, but none of
+ # the others
+ self.assertEqual(callbacks, {"prefill"})
+ callbacks.clear()
+
+ # another update should invalidate both the previous results
+ cache.prefill("k1", 30)
+ self.assertEqual(callbacks, {"set", "get"})
+
+ def test_set_fail(self):
+ cache = DeferredCache("test")
+ callbacks = set()
+
+ # start with an entry, with a callback
+ cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+ # now replace that entry with a pending result
+ origin_d = defer.Deferred()
+ set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+ # ... and also make a get request
+ get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+ # none of the callbacks should have been called yet
+ self.assertEqual(callbacks, set())
+
+ # oh noes! fails!
+ e = Exception("oops")
+ origin_d.errback(e)
+ self.assertIs(self.failureResultOf(set_d, Exception).value, e)
+ self.assertIs(self.failureResultOf(get_d, Exception).value, e)
+
+ # the callbacks for the failed requests should have been called.
+ # I'm not sure if this is deliberate or not.
+ self.assertEqual(callbacks, {"get", "set"})
+ callbacks.clear()
+
+ # the old value should still be returned now?
+ get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2"))
+ self.assertEqual(self.successResultOf(get_d2), 10)
+
+ # replacing the value now should run the callbacks for those requests
+ # which got the original result
+ cache.prefill("k1", 30)
+ self.assertEqual(callbacks, {"prefill", "get2"})
+
+ def test_get_immediate(self):
+ cache = DeferredCache("test")
+ d1 = defer.Deferred()
+ cache.set("key1", d1)
+
+ # get_immediate should return default
+ v = cache.get_immediate("key1", 1)
+ self.assertEqual(v, 1)
+
+ # now complete the set
+ d1.callback(2)
+
+ # get_immediate should return result
+ v = cache.get_immediate("key1", 1)
+ self.assertEqual(v, 2)
+
+ def test_invalidate(self):
+ cache = DeferredCache("test")
+ cache.prefill(("foo",), 123)
+ cache.invalidate(("foo",))
+
+ failed = False
+ try:
+ cache.get(("foo",))
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ def test_invalidate_all(self):
+ cache = DeferredCache("testcache")
+
+ callback_record = [False, False]
+
+ def record_callback(idx):
+ callback_record[idx] = True
+
+ # add a couple of pending entries
+ d1 = defer.Deferred()
+ cache.set("key1", d1, partial(record_callback, 0))
+
+ d2 = defer.Deferred()
+ cache.set("key2", d2, partial(record_callback, 1))
+
+ # lookup should return pending deferreds
+ self.assertFalse(cache.get("key1").called)
+ self.assertFalse(cache.get("key2").called)
+
+ # let one of the lookups complete
+ d2.callback("result2")
+
+ # now the cache will return a completed deferred
+ self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
+
+ # now do the invalidation
+ cache.invalidate_all()
+
+ # lookup should fail
+ with self.assertRaises(KeyError):
+ cache.get("key1")
+ with self.assertRaises(KeyError):
+ cache.get("key2")
+
+        # both invalidation callbacks should have been called
+ self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
+ self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
+
+ # letting the other lookup complete should do nothing
+ d1.callback("result1")
+ with self.assertRaises(KeyError):
+ cache.get("key1", None)
+
+ def test_eviction(self):
+ cache = DeferredCache(
+ "test", max_entries=2, apply_cache_factor_from_config=False
+ )
+
+ cache.prefill(1, "one")
+ cache.prefill(2, "two")
+ cache.prefill(3, "three") # 1 will be evicted
+
+ failed = False
+ try:
+ cache.get(1)
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ cache.get(2)
+ cache.get(3)
+
+ def test_eviction_lru(self):
+ cache = DeferredCache(
+ "test", max_entries=2, apply_cache_factor_from_config=False
+ )
+
+ cache.prefill(1, "one")
+ cache.prefill(2, "two")
+
+ # Now access 1 again, thus causing 2 to be least-recently used
+ cache.get(1)
+
+ cache.prefill(3, "three")
+
+ failed = False
+ try:
+ cache.get(2)
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ cache.get(1)
+ cache.get(3)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 677e925477..cf1e3203a4 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from functools import partial
+from typing import Set
import mock
@@ -29,60 +29,50 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, lru_cache
from tests import unittest
+from tests.test_utils import get_awaitable_result
logger = logging.getLogger(__name__)
-def run_on_reactor():
- d = defer.Deferred()
- reactor.callLater(0, d.callback, 0)
- return make_deferred_yieldable(d)
-
-
-class CacheTestCase(unittest.TestCase):
- def test_invalidate_all(self):
- cache = descriptors.Cache("testcache")
-
- callback_record = [False, False]
-
- def record_callback(idx):
- callback_record[idx] = True
-
- # add a couple of pending entries
- d1 = defer.Deferred()
- cache.set("key1", d1, partial(record_callback, 0))
-
- d2 = defer.Deferred()
- cache.set("key2", d2, partial(record_callback, 1))
-
- # lookup should return observable deferreds
- self.assertFalse(cache.get("key1").has_called())
- self.assertFalse(cache.get("key2").has_called())
+class LruCacheDecoratorTestCase(unittest.TestCase):
+ def test_base(self):
+ class Cls:
+ def __init__(self):
+ self.mock = mock.Mock()
- # let one of the lookups complete
- d2.callback("result2")
+ @lru_cache()
+ def fn(self, arg1, arg2):
+ return self.mock(arg1, arg2)
- # for now at least, the cache will return real results rather than an
- # observabledeferred
- self.assertEqual(cache.get("key2"), "result2")
+ obj = Cls()
+ obj.mock.return_value = "fish"
+ r = obj.fn(1, 2)
+ self.assertEqual(r, "fish")
+ obj.mock.assert_called_once_with(1, 2)
+ obj.mock.reset_mock()
- # now do the invalidation
- cache.invalidate_all()
+ # a call with different params should call the mock again
+ obj.mock.return_value = "chips"
+ r = obj.fn(1, 3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_called_once_with(1, 3)
+ obj.mock.reset_mock()
- # lookup should return none
- self.assertIsNone(cache.get("key1", None))
- self.assertIsNone(cache.get("key2", None))
+ # the two values should now be cached
+ r = obj.fn(1, 2)
+ self.assertEqual(r, "fish")
+ r = obj.fn(1, 3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_not_called()
- # both callbacks should have been callbacked
- self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
- self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
- # letting the other lookup complete should do nothing
- d1.callback("result1")
- self.assertIsNone(cache.get("key1", None))
+def run_on_reactor():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, 0)
+ return make_deferred_yieldable(d)
class DescriptorTestCase(unittest.TestCase):
@@ -174,6 +164,57 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
+ def test_cache_with_async_exception(self):
+ """The wrapped function returns a failure
+ """
+
+ class Cls:
+ result = None
+ call_count = 0
+
+ @cached()
+ def fn(self, arg1):
+ self.call_count += 1
+ return self.result
+
+ obj = Cls()
+ callbacks = set() # type: Set[str]
+
+ # set off an asynchronous request
+ obj.result = origin_d = defer.Deferred()
+
+ d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
+ self.assertFalse(d1.called)
+
+ # a second request should also return a deferred, but should not call the
+ # function itself.
+ d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
+ self.assertFalse(d2.called)
+ self.assertEqual(obj.call_count, 1)
+
+ # no callbacks yet
+ self.assertEqual(callbacks, set())
+
+ # the original request fails
+ e = Exception("bzz")
+ origin_d.errback(e)
+
+ # ... which should cause the lookups to fail similarly
+ self.assertIs(self.failureResultOf(d1, Exception).value, e)
+ self.assertIs(self.failureResultOf(d2, Exception).value, e)
+
+ # ... and the callbacks to have been, uh, called.
+ self.assertEqual(callbacks, {"d1", "d2"})
+
+ # ... leaving the cache empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
+ # and a second call should work as normal
+ obj.result = defer.succeed(100)
+ d3 = obj.fn(1)
+ self.assertEqual(self.successResultOf(d3), 100)
+ self.assertEqual(obj.call_count, 2)
+
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
@@ -354,6 +395,260 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
+ def test_invalidate_cascade(self):
+ """Invalidations should cascade up through cache contexts"""
+
+ class Cls:
+ @cached(cache_context=True)
+ async def func1(self, key, cache_context):
+ return await self.func2(key, on_invalidate=cache_context.invalidate)
+
+ @cached(cache_context=True)
+ async def func2(self, key, cache_context):
+ return self.func3(key, on_invalidate=cache_context.invalidate)
+
+ @lru_cache(cache_context=True)
+ def func3(self, key, cache_context):
+ self.invalidate = cache_context.invalidate
+ return 42
+
+ obj = Cls()
+
+ top_invalidate = mock.Mock()
+ r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
+ self.assertEqual(r, 42)
+ obj.invalidate()
+ top_invalidate.assert_called_once()
+
+
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
+ """More tests for @cached
+
+ The following is a set of tests that got lost in a different file for a while.
+
+ There are probably duplicates of the tests in DescriptorTestCase. Ideally the
+ duplicates would be removed and the two sets of classes combined.
+ """
+
+ @defer.inlineCallbacks
+ def test_passthrough(self):
+ class A:
+ @cached()
+ def func(self, key):
+ return key
+
+ a = A()
+
+ self.assertEquals((yield a.func("foo")), "foo")
+ self.assertEquals((yield a.func("bar")), "bar")
+
+ @defer.inlineCallbacks
+ def test_hit(self):
+ callcount = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 1)
+
+ self.assertEquals((yield a.func("foo")), "foo")
+ self.assertEquals(callcount[0], 1)
+
+ @defer.inlineCallbacks
+ def test_invalidate(self):
+ callcount = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 1)
+
+ a.func.invalidate(("foo",))
+
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+
+ def test_invalidate_missing(self):
+ class A:
+ @cached()
+ def func(self, key):
+ return key
+
+ A().func.invalidate(("what",))
+
+ @defer.inlineCallbacks
+ def test_max_entries(self):
+ callcount = [0]
+
+ class A:
+ @cached(max_entries=10)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+
+ for k in range(0, 12):
+ yield a.func(k)
+
+ self.assertEquals(callcount[0], 12)
+
+ # There must have been at least 2 evictions, meaning if we calculate
+ # all 12 values again, we must get called at least 2 more times
+ for k in range(0, 12):
+ yield a.func(k)
+
+ self.assertTrue(
+ callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
+ )
+
+ def test_prefill(self):
+ callcount = [0]
+
+ d = defer.succeed(123)
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return d
+
+ a = A()
+
+ a.func.prefill(("foo",), 456)
+
+ self.assertEquals(a.func("foo").result, 456)
+ self.assertEquals(callcount[0], 0)
+
+ @defer.inlineCallbacks
+ def test_invalidate_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func.invalidate(("foo",))
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 1)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ @defer.inlineCallbacks
+ def test_eviction_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached(max_entries=2)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+ yield a.func2("foo2")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func("foo3")
+
+ self.assertEquals(callcount[0], 3)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 4)
+ self.assertEquals(callcount2[0], 3)
+
+ @defer.inlineCallbacks
+ def test_double_get(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+ yield a.func2("foo")
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 2)
+
+ a.func.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 3)
+
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 0adb2174af..a739a6aaaf 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,7 +19,8 @@ from mock import Mock
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
-from .. import unittest
+from tests import unittest
+from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase):
@@ -59,7 +60,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEquals(cache.pop("key"), None)
def test_del_multi(self):
- cache = LruCache(4, 2, cache_type=TreeCache)
+ cache = LruCache(4, keylen=2, cache_type=TreeCache)
cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof"
cache[("vehicles", "car")] = "vroom"
@@ -83,6 +84,11 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
cache.clear()
self.assertEquals(len(cache), 0)
+ @override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
+ def test_special_size(self):
+ cache = LruCache(10, "mycache")
+ self.assertEqual(cache.max_size, 100)
+
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self):
@@ -160,7 +166,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
m2 = Mock()
m3 = Mock()
m4 = Mock()
- cache = LruCache(4, 2, cache_type=TreeCache)
+ cache = LruCache(4, keylen=2, cache_type=TreeCache)
cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2])
diff --git a/tests/utils.py b/tests/utils.py
index 4673872f88..c8d3ffbaba 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,6 +21,7 @@ import time
import uuid
import warnings
from inspect import getcallargs
+from typing import Type
from urllib import parse as urlparse
from mock import Mock, patch
@@ -38,6 +39,7 @@ from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -88,6 +90,7 @@ def setupdb():
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
+ db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(db_conn, db_engine, None)
db_conn.close()
@@ -190,11 +193,10 @@ class TestHomeServer(HomeServer):
def setup_test_homeserver(
cleanup_func,
name="test",
- datastore=None,
config=None,
reactor=None,
- homeserverToUse=TestHomeServer,
- **kargs
+ homeserver_to_use: Type[HomeServer] = TestHomeServer,
+ **kwargs
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
@@ -217,8 +219,8 @@ def setup_test_homeserver(
config.ldap_enabled = False
- if "clock" not in kargs:
- kargs["clock"] = MockClock()
+ if "clock" not in kwargs:
+ kwargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
@@ -247,7 +249,7 @@ def setup_test_homeserver(
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
- if datastore is None and isinstance(db_engine, PostgresEngine):
+ if isinstance(db_engine, PostgresEngine):
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
@@ -263,79 +265,68 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
- if datastore is None:
- hs = homeserverToUse(
- name,
- config=config,
- version_string="Synapse/tests",
- tls_server_context_factory=Mock(),
- tls_client_options_factory=Mock(),
- reactor=reactor,
- **kargs
- )
+ hs = homeserver_to_use(
+ name, config=config, version_string="Synapse/tests", reactor=reactor,
+ )
- hs.setup()
- if homeserverToUse.__name__ == "TestHomeServer":
- hs.setup_master()
+ # Install @cache_in_self attributes
+ for key, val in kwargs.items():
+ setattr(hs, "_" + key, val)
- if isinstance(db_engine, PostgresEngine):
- database = hs.get_datastores().databases[0]
+ # Mock TLS
+ hs.tls_server_context_factory = Mock()
+ hs.tls_client_options_factory = Mock()
- # We need to do cleanup on PostgreSQL
- def cleanup():
- import psycopg2
+ hs.setup()
+ if homeserver_to_use == TestHomeServer:
+ hs.setup_background_tasks()
- # Close all the db pools
- database._db_pool.close()
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
- dropped = False
+ # We need to do cleanup on PostgreSQL
+ def cleanup():
+ import psycopg2
- # Drop the test database
- db_conn = db_engine.module.connect(
- database=POSTGRES_BASE_DB,
- user=POSTGRES_USER,
- host=POSTGRES_HOST,
- password=POSTGRES_PASSWORD,
- )
- db_conn.autocommit = True
- cur = db_conn.cursor()
-
- # Try a few times to drop the DB. Some things may hold on to the
- # database for a few more seconds due to flakiness, preventing
- # us from dropping it when the test is over. If we can't drop
- # it, warn and move on.
- for x in range(5):
- try:
- cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
- db_conn.commit()
- dropped = True
- except psycopg2.OperationalError as e:
- warnings.warn(
- "Couldn't drop old db: " + str(e), category=UserWarning
- )
- time.sleep(0.5)
-
- cur.close()
- db_conn.close()
-
- if not dropped:
- warnings.warn("Failed to drop old DB.", category=UserWarning)
-
- if not LEAVE_DB:
- # Register the cleanup hook
- cleanup_func(cleanup)
+ # Close all the db pools
+ database._db_pool.close()
- else:
- hs = homeserverToUse(
- name,
- datastore=datastore,
- config=config,
- version_string="Synapse/tests",
- tls_server_context_factory=Mock(),
- tls_client_options_factory=Mock(),
- reactor=reactor,
- **kargs
- )
+ dropped = False
+
+ # Drop the test database
+ db_conn = db_engine.module.connect(
+ database=POSTGRES_BASE_DB,
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ )
+ db_conn.autocommit = True
+ cur = db_conn.cursor()
+
+ # Try a few times to drop the DB. Some things may hold on to the
+ # database for a few more seconds due to flakiness, preventing
+ # us from dropping it when the test is over. If we can't drop
+ # it, warn and move on.
+ for x in range(5):
+ try:
+ cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+ db_conn.commit()
+ dropped = True
+ except psycopg2.OperationalError as e:
+ warnings.warn(
+ "Couldn't drop old db: " + str(e), category=UserWarning
+ )
+ time.sleep(0.5)
+
+ cur.close()
+ db_conn.close()
+
+ if not dropped:
+ warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+ if not LEAVE_DB:
+ # Register the cleanup hook
+ cleanup_func(cleanup)
# bcrypt is far too slow to be doing in unit tests
# Need to let the HS build an auth handler and then mess with it
@@ -351,7 +342,7 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash
- fed = kargs.get("resource_for_federation", None)
+ fed = kwargs.get("resource_for_federation", None)
if fed:
register_federation_servlets(hs, fed)
|