diff --git a/changelog.d/6513.misc b/changelog.d/6513.misc
new file mode 100644
index 0000000000..36700f5657
--- /dev/null
+++ b/changelog.d/6513.misc
@@ -0,0 +1 @@
+Remove all assumptions of there being a single physical DB apart from the `synapse.config`.
diff --git a/scripts-dev/update_database b/scripts-dev/update_database
index 23017c21f8..1d62f0403a 100755
--- a/scripts-dev/update_database
+++ b/scripts-dev/update_database
@@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger("update_database")
@@ -77,12 +76,8 @@ if __name__ == "__main__":
# Instantiate and initialise the homeserver object.
hs = MockHomeserver(config)
- db_conn = hs.get_db_conn()
- # Update the database to the latest schema.
- prepare_database(db_conn, hs.database_engine, config=config)
- db_conn.commit()
-
- # setup instantiates the store within the homeserver object.
+ # Setup instantiates the store within the homeserver object and updates the
+ # DB.
hs.setup()
store = hs.get_datastore()
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index e393a9b2f7..5b5368988c 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -30,6 +30,7 @@ import yaml
from twisted.enterprise import adbapi
from twisted.internet import defer, reactor
+from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.logging.context import PreserveLoggingContext
from synapse.storage._base import LoggingTransaction
@@ -55,7 +56,7 @@ from synapse.storage.data_stores.main.stats import StatsStore
from synapse.storage.data_stores.main.user_directory import (
UserDirectoryBackgroundUpdateStore,
)
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util import Clock
@@ -165,23 +166,17 @@ class Store(
class MockHomeserver:
- def __init__(self, config, database_engine, db_conn, db_pool):
- self.database_engine = database_engine
- self.db_conn = db_conn
- self.db_pool = db_pool
+ def __init__(self, config):
self.clock = Clock(reactor)
self.config = config
self.hostname = config.server_name
- def get_db_conn(self):
- return self.db_conn
-
- def get_db_pool(self):
- return self.db_pool
-
def get_clock(self):
return self.clock
+ def get_reactor(self):
+ return reactor
+
class Porter(object):
def __init__(self, **kwargs):
@@ -445,45 +440,36 @@ class Porter(object):
else:
return
- def setup_db(self, db_config, database_engine):
- db_conn = database_engine.module.connect(
- **{
- k: v
- for k, v in db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- )
-
- prepare_database(db_conn, database_engine, config=None)
+ def setup_db(self, db_config: DatabaseConnectionConfig, engine):
+ db_conn = make_conn(db_config, engine)
+ prepare_database(db_conn, engine, config=None)
db_conn.commit()
return db_conn
@defer.inlineCallbacks
- def build_db_store(self, config):
+ def build_db_store(self, db_config: DatabaseConnectionConfig):
"""Builds and returns a database store using the provided configuration.
Args:
- config: The database configuration, i.e. a dict following the structure of
- the "database" section of Synapse's configuration file.
+ db_config: The database configuration
Returns:
The built Store object.
"""
- engine = create_engine(config)
-
- self.progress.set_state("Preparing %s" % config["name"])
- conn = self.setup_db(config, engine)
+ self.progress.set_state("Preparing %s" % db_config.config["name"])
- db_pool = adbapi.ConnectionPool(config["name"], **config["args"])
+ engine = create_engine(db_config.config)
+ conn = self.setup_db(db_config, engine)
- hs = MockHomeserver(self.hs_config, engine, conn, db_pool)
+ hs = MockHomeserver(self.hs_config)
- store = Store(Database(hs), conn, hs)
+ store = Store(Database(hs, db_config, engine), conn, hs)
yield store.db.runInteraction(
- "%s_engine.check_database" % config["name"], engine.check_database,
+ "%s_engine.check_database" % db_config.config["name"],
+ engine.check_database,
)
return store
@@ -509,7 +495,11 @@ class Porter(object):
@defer.inlineCallbacks
def run(self):
try:
- self.sqlite_store = yield self.build_db_store(self.sqlite_config)
+ self.sqlite_store = yield self.build_db_store(
+ DatabaseConnectionConfig(
+ "master", self.sqlite_config, data_stores=["main"]
+ )
+ )
# Check if all background updates are done, abort if not.
updates_complete = (
@@ -524,7 +514,7 @@ class Porter(object):
defer.returnValue(None)
self.postgres_store = yield self.build_db_store(
- self.hs_config.database_config
+ self.hs_config.get_single_database()
)
yield self.run_background_updates_on_postgres()
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 0e2509f0b1..5f2f3c7cfd 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -12,12 +12,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import os
from textwrap import indent
+from typing import List
import yaml
-from ._base import Config
+from synapse.config._base import Config, ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DatabaseConnectionConfig:
+ """Contains the connection config for a particular database.
+
+ Args:
+ name: A label for the database, used for logging.
+ db_config: The config for a particular database, as per `database`
+ section of main config. Has two fields: `name` for database
+ module name, and `args` for the args to give to the database
+ connector.
+ data_stores: The list of data stores that should be provisioned on the
+ database.
+ """
+
+ def __init__(self, name: str, db_config: dict, data_stores: List[str]):
+ if db_config["name"] not in ("sqlite3", "psycopg2"):
+ raise ConfigError("Unsupported database type %r" % (db_config["name"],))
+
+ if db_config["name"] == "sqlite3":
+ db_config.setdefault("args", {}).update(
+ {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
+ )
+
+ self.name = name
+ self.config = db_config
+ self.data_stores = data_stores
class DatabaseConfig(Config):
@@ -26,20 +57,14 @@ class DatabaseConfig(Config):
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
- self.database_config = config.get("database")
+ database_config = config.get("database")
- if self.database_config is None:
- self.database_config = {"name": "sqlite3", "args": {}}
+ if database_config is None:
+ database_config = {"name": "sqlite3", "args": {}}
- name = self.database_config.get("name", None)
- if name == "psycopg2":
- pass
- elif name == "sqlite3":
- self.database_config.setdefault("args", {}).update(
- {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
- )
- else:
- raise RuntimeError("Unsupported database type '%s'" % (name,))
+ self.databases = [
+ DatabaseConnectionConfig("master", database_config, data_stores=["main"])
+ ]
self.set_databasepath(config.get("database_path"))
@@ -76,11 +101,24 @@ class DatabaseConfig(Config):
self.set_databasepath(args.database_path)
def set_databasepath(self, database_path):
+ if database_path is None:
+ return
+
if database_path != ":memory:":
database_path = self.abspath(database_path)
- if self.database_config.get("name", None) == "sqlite3":
- if database_path is not None:
- self.database_config["args"]["database"] = database_path
+
+ # We only support setting a database path if we have a single sqlite3
+ # database.
+ if len(self.databases) != 1:
+ raise ConfigError("Cannot specify 'database_path' with multiple databases")
+
+ database = self.get_single_database()
+ if database.config["name"] != "sqlite3":
+ # We don't raise here as we haven't done so before for this case.
+ logger.warning("Ignoring 'database_path' for non-sqlite3 database")
+ return
+
+ database.config["args"]["database"] = database_path
@staticmethod
def add_arguments(parser):
@@ -91,3 +129,11 @@ class DatabaseConfig(Config):
metavar="SQLITE_DATABASE_PATH",
help="The path to a sqlite database to use.",
)
+
+ def get_single_database(self) -> DatabaseConnectionConfig:
+ """Returns the database if there is only one, useful for e.g. tests
+ """
+ if len(self.databases) != 1:
+ raise Exception("More than one database exists")
+
+ return self.databases[0]
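
To make the new shape concrete, here is a short sketch of how a `database` section maps onto `DatabaseConnectionConfig`; it assumes this revision of synapse is importable, and the config dict and path are hypothetical.

```python
# A hedged sketch, assuming this revision of synapse; the config dict and
# database path below are hypothetical examples.
from synapse.config.database import DatabaseConnectionConfig

db_section = {"name": "sqlite3", "args": {"database": "/data/homeserver.db"}}

conn_config = DatabaseConnectionConfig("master", db_section, data_stores=["main"])

# For sqlite3 the constructor pins the pool to a single connection and sets
# check_same_thread=False, mirroring the old read_config() behaviour.
assert conn_config.config["args"]["cp_max"] == 1

# An unsupported database module is rejected up front with a ConfigError:
# DatabaseConnectionConfig("master", {"name": "mysql"}, data_stores=["main"])
```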
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index eda15bc623..240c4add12 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -230,7 +230,7 @@ class PresenceHandler(object):
is some spurious presence changes that will self-correct.
"""
# If the DB pool has already terminated, don't try updating
- if not self.hs.get_db_pool().running:
+ if not self.store.db.is_running():
return
logger.info(
diff --git a/synapse/server.py b/synapse/server.py
index 5021068ce0..7926867b77 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -25,7 +25,6 @@ import abc
import logging
import os
-from twisted.enterprise import adbapi
from twisted.mail.smtp import sendmail
from twisted.web.client import BrowserLikePolicyForHTTPS
@@ -98,7 +97,6 @@ from synapse.server_notices.worker_server_notices_sender import (
)
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStores, Storage
-from synapse.storage.engines import create_engine
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -134,7 +132,6 @@ class HomeServer(object):
DEPENDENCIES = [
"http_client",
- "db_pool",
"federation_client",
"federation_server",
"handlers",
@@ -233,12 +230,6 @@ class HomeServer(object):
self.admin_redaction_ratelimiter = Ratelimiter()
self.registration_ratelimiter = Ratelimiter()
- self.database_engine = create_engine(config.database_config)
- config.database_config.setdefault("args", {})[
- "cp_openfun"
- ] = self.database_engine.on_new_connection
- self.db_config = config.database_config
-
self.datastores = None
# Other kwargs are explicit dependencies
@@ -247,10 +238,8 @@ class HomeServer(object):
def setup(self):
logger.info("Setting up.")
- with self.get_db_conn() as conn:
- self.datastores = DataStores(self.DATASTORE_CLASS, conn, self)
- conn.commit()
self.start_time = int(self.get_clock().time())
+ self.datastores = DataStores(self.DATASTORE_CLASS, self)
logger.info("Finished setting up.")
def setup_master(self):
@@ -284,6 +273,9 @@ class HomeServer(object):
def get_datastore(self):
return self.datastores.main
+ def get_datastores(self):
+ return self.datastores
+
def get_config(self):
return self.config
@@ -433,31 +425,6 @@ class HomeServer(object):
)
return MatrixFederationHttpClient(self, tls_client_options_factory)
- def build_db_pool(self):
- name = self.db_config["name"]
-
- return adbapi.ConnectionPool(
- name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {})
- )
-
- def get_db_conn(self, run_new_connection=True):
- """Makes a new connection to the database, skipping the db pool
-
- Returns:
- Connection: a connection object implementing the PEP-249 spec
- """
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v
- for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def build_media_repository_resource(self):
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b7637b5dc0..88546ad614 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -40,7 +40,7 @@ class SQLBaseStore(object):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
- self.database_engine = hs.database_engine
+ self.database_engine = database.engine
self.db = database
self.rand = random.SystemRandom()
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cafedd5c0d..0983e059c0 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -13,24 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.database import Database
+import logging
+
+from synapse.storage.database import Database, make_conn
+from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
+logger = logging.getLogger(__name__)
+
class DataStores(object):
"""The various data stores.
These are low level interfaces to physical databases.
+
+ Attributes:
+ main (DataStore)
"""
- def __init__(self, main_store_class, db_conn, hs):
+ def __init__(self, main_store_class, hs):
# Note we pass in the main store class here as workers use a different main
# store.
- database = Database(hs)
- # Check that db is correctly configured.
- database.engine.check_database(db_conn.cursor())
+ self.databases = []
+
+ for database_config in hs.config.database.databases:
+ db_name = database_config.name
+ engine = create_engine(database_config.config)
+
+ with make_conn(database_config, engine) as db_conn:
+ logger.info("Preparing database %r...", db_name)
+
+ engine.check_database(db_conn.cursor())
+ prepare_database(
+ db_conn, engine, hs.config, data_stores=database_config.data_stores,
+ )
+
+ database = Database(hs, database_config, engine)
+
+ if "main" in database_config.data_stores:
+ logger.info("Starting 'main' data store")
+ self.main = main_store_class(database, db_conn, hs)
+
+ db_conn.commit()
- prepare_database(db_conn, database.engine, config=hs.config)
+ self.databases.append(database)
- self.main = main_store_class(database, db_conn, hs)
+ logger.info("Database %r prepared", db_name)
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index add3037b69..13f4c9c72e 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating
- if not self.hs.get_db_pool().running:
+ if not self.db.is_running():
return
to_update = self._batch_row_update
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ec19ae1d9d..1003dd84a5 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -24,9 +24,11 @@ from six.moves import intern, range
from prometheus_client import Histogram
+from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
@@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
}
+def make_pool(
+ reactor, db_config: DatabaseConnectionConfig, engine
+) -> adbapi.ConnectionPool:
+ """Get the connection pool for the database.
+ """
+
+ return adbapi.ConnectionPool(
+ db_config.config["name"],
+ cp_reactor=reactor,
+ cp_openfun=engine.on_new_connection,
+ **db_config.config.get("args", {})
+ )
+
+
+def make_conn(db_config: DatabaseConnectionConfig, engine):
+ """Make a new connection to the database and return it.
+
+ Returns:
+ Connection
+ """
+
+ db_params = {
+ k: v
+ for k, v in db_config.config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = engine.module.connect(**db_params)
+ engine.on_new_connection(db_conn)
+ return db_conn
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@@ -218,10 +251,11 @@ class Database(object):
_TXN_ID = 0
- def __init__(self, hs):
+ def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
self.hs = hs
self._clock = hs.get_clock()
- self._db_pool = hs.get_db_pool()
+ self._database_config = database_config
+ self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
self.updates = BackgroundUpdater(hs, self)
@@ -234,7 +268,7 @@ class Database(object):
# to watch it
self._txn_perf_counters = PerformanceCounters()
- self.engine = hs.database_engine
+ self.engine = engine
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@@ -255,6 +289,11 @@ class Database(object):
self._check_safe_to_upsert,
)
+ def is_running(self):
+ """Is the database pool currently running
+ """
+ return self._db_pool.running
+
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index ddad17dc5a..df039a072d 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -16,8 +16,6 @@
import struct
import threading
-from synapse.storage.prepare_database import prepare_database
-
class Sqlite3Engine(object):
single_threaded = True
@@ -25,6 +23,9 @@ class Sqlite3Engine(object):
def __init__(self, database_module, database_config):
self.module = database_module
+ database = database_config.get("args", {}).get("database")
+ self._is_in_memory = database in (None, ":memory:",)
+
# The current max state_group, or None if we haven't looked
# in the DB yet.
self._current_state_group_id = None
@@ -59,7 +60,16 @@ class Sqlite3Engine(object):
return sql
def on_new_connection(self, db_conn):
- prepare_database(db_conn, self, config=None)
+
+ # We need to import here to avoid an import loop.
+ from synapse.storage.prepare_database import prepare_database
+
+ if self._is_in_memory:
+ # In memory databases need to be rebuilt each time. Ideally we'd
+ # reuse the same connection as we do when starting up, but that
+ # would involve using adbapi before we have started the reactor.
+ prepare_database(db_conn, self, config=None)
+
db_conn.create_function("rank", 1, _rank)
def is_deadlock(self, error):
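
The effect of the new `_is_in_memory` flag can be seen directly; a sketch that pokes at the private attribute purely for illustration, with a hypothetical file path.

```python
# A sketch of the new behaviour; inspecting the private _is_in_memory flag is
# for illustration only, and the file path is hypothetical.
import sqlite3

from synapse.storage.engines.sqlite import Sqlite3Engine

mem = Sqlite3Engine(sqlite3, {"args": {"database": ":memory:"}})
disk = Sqlite3Engine(sqlite3, {"args": {"database": "/data/homeserver.db"}})

print(mem._is_in_memory)   # True  -> schema rebuilt on every new connection
print(disk._is_in_memory)  # False -> schema prepared once, by DataStores
```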
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 731e1c9d9c..b4194b44ee 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -41,7 +41,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass
-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main"]):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -54,11 +54,10 @@ def prepare_database(db_conn, database_engine, config):
config (synapse.config.homeserver.HomeServerConfig|None):
application config, or None if we are connecting to an existing
database which we expect to be configured already
+ data_stores (list[str]): The names of the data stores that will be used
+ with this database. Defaults to ["main"].
"""
- # For now we only have the one datastore.
- data_stores = ["main"]
-
try:
cur = db_conn.cursor()
version_info = _get_or_create_schema_state(cur, database_engine)
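
With the `data_stores` argument exposed, preparing a database that hosts a given set of data stores becomes explicit. A hedged sketch against a fresh in-memory sqlite database; "main" is the only data store that exists at this revision, so the list matches the default.

```python
# A hedged sketch; "main" is the only data store at this revision, so the
# explicit list below matches the default.
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database

engine = create_engine({"name": "sqlite3", "args": {"database": ":memory:"}})
conn = engine.module.connect(":memory:")

# On a fresh database this builds the full schema for the listed data stores,
# as setup_db() in synapse_port_db does above.
prepare_database(conn, engine, config=None, data_stores=["main"])
conn.commit()
conn.close()
```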
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 92b8726093..596ddc6970 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -64,28 +64,29 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
+ datastores = Mock()
+ datastores.main = Mock(
+ spec=[
+ # Bits that Federation needs
+ "prep_send_transaction",
+ "delivered_txn",
+ "get_received_txn_response",
+ "set_received_txn_response",
+ "get_destination_retry_timings",
+ "get_devices_by_remote",
+ # Bits that user_directory needs
+ "get_user_directory_stream_pos",
+ "get_current_state_deltas",
+ "get_device_updates_by_remote",
+ ]
+ )
+
hs = self.setup_test_homeserver(
- datastore=(
- Mock(
- spec=[
- # Bits that Federation needs
- "prep_send_transaction",
- "delivered_txn",
- "get_received_txn_response",
- "set_received_txn_response",
- "get_destination_retry_timings",
- "get_device_updates_by_remote",
- # Bits that user_directory needs
- "get_user_directory_stream_pos",
- "get_current_state_deltas",
- ]
- )
- ),
- notifier=Mock(),
- http_client=mock_federation_client,
- keyring=mock_keyring,
+ notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
)
+ hs.datastores = datastores
+
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 3dae83c543..2a1e7c7166 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,7 +20,7 @@ from synapse.replication.tcp.client import (
ReplicationClientHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.storage.database import Database
+from synapse.storage.database import make_conn
from tests import unittest
from tests.server import FakeTransport
@@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
+ db_config = hs.config.database.get_single_database()
self.master_store = self.hs.get_datastore()
self.storage = hs.get_storage()
+ database = hs.get_datastores().databases[0]
self.slaved_store = self.STORE_TYPE(
- Database(hs), self.hs.get_db_conn(), self.hs
+ database, make_conn(db_config, database.engine), self.hs
)
self.event_id = 0
diff --git a/tests/server.py b/tests/server.py
index 2b7cf4242e..a554dfdd57 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
Set up a synchronous test server, driven by the reactor used by
the homeserver.
"""
- d = _sth(cleanup_func, *args, **kwargs).result
+ server = _sth(cleanup_func, *args, **kwargs)
- if isinstance(d, Failure):
- d.raiseException()
+ database = server.config.database.get_single_database()
# Make the thread pool synchronous.
- clock = d.get_clock()
- pool = d.get_db_pool()
-
- def runWithConnection(func, *args, **kwargs):
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runWithConnection,
- func,
- *args,
- **kwargs
- )
-
- def runInteraction(interaction, *args, **kwargs):
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runInteraction,
- interaction,
- *args,
- **kwargs
- )
+ clock = server.get_clock()
+
+ for database in server.get_datastores().databases:
+ pool = database._db_pool
+
+ def runWithConnection(func, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runWithConnection,
+ func,
+ *args,
+ **kwargs
+ )
+
+ def runInteraction(interaction, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runInteraction,
+ interaction,
+ *args,
+ **kwargs
+ )
- if pool:
pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
- return d
+
+ return server
def get_clock():
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 2e521e9ab7..fd52512696 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
- database = Database(hs)
- self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ self.store = ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
@@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
hs.config.event_cache_size = 1
hs.config.password_providers = []
- self.db_pool = hs.get_db_pool()
- self.engine = hs.database_engine
-
self.as_list = [
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
{"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
@@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = []
- database = Database(hs)
- self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
+ # We assume there is only one database in these tests
+ database = hs.get_datastores().databases[0]
+ self.db_pool = database._db_pool
+ self.engine = database.engine
+
+ db_config = hs.config.get_single_database()
+ self.store = TestTransactionStore(
+ database, make_conn(db_config, self.engine), hs
+ )
def _add_service(self, url, as_token, id):
as_yaml = dict(
@@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.event_cache_size = 1
hs.config.password_providers = []
- ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
@defer.inlineCallbacks
def test_duplicate_ids(self):
@@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
e = cm.exception
self.assertIn(f1, str(e))
@@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
e = cm.exception
self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 537cfe9f64..cdee0a9e60 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase):
config = Mock()
config._disable_native_upserts = True
config.event_cache_size = 1
- config.database_config = {"name": "sqlite3"}
- engine = create_engine(config.database_config)
+ hs = TestHomeServer("test", config=config)
+
+ sqlite_config = {"name": "sqlite3"}
+ engine = create_engine(sqlite_config)
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
- hs = TestHomeServer(
- "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
- )
- self.datastore = SQLBaseStore(Database(hs), None, hs)
+ db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+ db._db_pool = self.db_pool
+
+ self.datastore = SQLBaseStore(db, None, hs)
@defer.inlineCallbacks
def test_insert_1col(self):
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4578cc3b60..ed5786865a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
- self.db_pool = hs.get_db_pool()
self.store = hs.get_datastore()
diff --git a/tests/utils.py b/tests/utils.py
index 585f305b9a..9f5bf40b4b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,6 +30,7 @@ from twisted.internet import defer, reactor
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
@@ -177,7 +178,6 @@ class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
-@defer.inlineCallbacks
def setup_test_homeserver(
cleanup_func,
name="test",
@@ -214,7 +214,7 @@ def setup_test_homeserver(
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
- config.database_config = {
+ database_config = {
"name": "psycopg2",
"args": {
"database": test_db,
@@ -226,12 +226,15 @@ def setup_test_homeserver(
},
}
else:
- config.database_config = {
+ database_config = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}
- db_engine = create_engine(config.database_config)
+ database = DatabaseConnectionConfig("master", database_config, ["main"])
+ config.database.databases = [database]
+
+ db_engine = create_engine(database.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
@@ -251,11 +254,6 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
- # we need to configure the connection pool to run the on_new_connection
- # function, so that we can test code that uses custom sqlite functions
- # (like rank).
- config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
if datastore is None:
hs = homeserverToUse(
name,
@@ -267,21 +265,19 @@ def setup_test_homeserver(
**kargs
)
- # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
- # date db
- if not isinstance(db_engine, PostgresEngine):
- db_conn = hs.get_db_conn()
- yield prepare_database(db_conn, db_engine, config)
- db_conn.commit()
- db_conn.close()
+ hs.setup()
+ if homeserverToUse.__name__ == "TestHomeServer":
+ hs.setup_master()
+
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
- else:
# We need to do cleanup on PostgreSQL
def cleanup():
import psycopg2
# Close all the db pools
- hs.get_db_pool().close()
+ database._db_pool.close()
dropped = False
@@ -320,23 +316,12 @@ def setup_test_homeserver(
# Register the cleanup hook
cleanup_func(cleanup)
- hs.setup()
- if homeserverToUse.__name__ == "TestHomeServer":
- hs.setup_master()
else:
- # If we have been given an explicit datastore we probably want to mock
- # out the DataStores somehow too. This all feels a bit wrong, but then
- # mocking the stores feels wrong too.
- datastores = Mock(datastore=datastore)
-
hs = homeserverToUse(
name,
- db_pool=None,
datastore=datastore,
- datastores=datastores,
config=config,
version_string="Synapse/tests",
- database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
|