diff options
author | Amber Brown <hawkowl@atleastfornow.net> | 2018-08-14 20:56:23 +1000 |
---|---|---|
committer | Amber Brown <hawkowl@atleastfornow.net> | 2018-08-14 20:56:23 +1000 |
commit | 591bf87c6afcdb4e8978a275219dd10dba4efc25 (patch) | |
tree | a2620b679bcaa6ecf75388be0ed5e17fa350d2cd /tests/utils.py | |
parent | Fixes test_reap_monthly_active_users so it passes under postgres (diff) | |
parent | Implement a new test baseclass to cut down on boilerplate (#3684) (diff) | |
download | synapse-591bf87c6afcdb4e8978a275219dd10dba4efc25.tar.xz |
Merge remote-tracking branch 'origin/develop' into neilj/fix_reap_users_in_postgres
Diffstat (limited to 'tests/utils.py')
-rw-r--r-- | tests/utils.py | 133 |
1 files changed, 115 insertions, 18 deletions
diff --git a/tests/utils.py b/tests/utils.py index 8668b5478f..90378326f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import atexit import hashlib +import os +import uuid from inspect import getcallargs from mock import Mock, patch @@ -27,23 +30,80 @@ from synapse.http.server import HttpServer from synapse.server import HomeServer from synapse.storage import PostgresEngine from synapse.storage.engines import create_engine -from synapse.storage.prepare_database import prepare_database +from synapse.storage.prepare_database import ( + _get_or_create_schema_state, + _setup_new_database, + prepare_database, +) from synapse.util.logcontext import LoggingContext from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. -# It requires you to have a local postgres database called synapse_test, within -# which ALL TABLES WILL BE DROPPED -USE_POSTGRES_FOR_TESTS = False +USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) +POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres") +POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) + + +def setupdb(): + + # If we're using PostgreSQL, set up the db once + if USE_POSTGRES_FOR_TESTS: + pgconfig = { + "name": "psycopg2", + "args": { + "database": POSTGRES_BASE_DB, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, + } + config = Mock() + config.password_providers = [] + config.database_config = pgconfig + db_engine = create_engine(pgconfig) + db_conn = db_engine.module.connect(user=POSTGRES_USER) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) + cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,)) + cur.close() + db_conn.close() + + # Set up in the db + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + cur = db_conn.cursor() + _get_or_create_schema_state(cur, db_engine) + _setup_new_database(cur, db_engine) + db_conn.commit() + cur.close() + db_conn.close() + + def _cleanup(): + db_conn = db_engine.module.connect(user=POSTGRES_USER) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) + cur.close() + db_conn.close() + + atexit.register(_cleanup) @defer.inlineCallbacks def setup_test_homeserver( - name="test", datastore=None, config=None, reactor=None, **kargs + cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs ): - """Setup a homeserver suitable for running tests against. Keyword arguments - are passed to the Homeserver constructor. If no datastore is supplied a - datastore backed by an in-memory sqlite db will be given to the HS. + """ + Setup a homeserver suitable for running tests against. Keyword arguments + are passed to the Homeserver constructor. + + If no datastore is supplied, one is created and given to the homeserver. + + Args: + cleanup_func : The function used to register a cleanup routine for + after the test. """ if reactor is None: from twisted.internet import reactor @@ -95,9 +155,11 @@ def setup_test_homeserver( kargs["clock"] = MockClock() if USE_POSTGRES_FOR_TESTS: + test_db = "synapse_test_%s" % uuid.uuid4().hex + config.database_config = { "name": "psycopg2", - "args": {"database": "synapse_test", "cp_min": 1, "cp_max": 5}, + "args": {"database": test_db, "cp_min": 1, "cp_max": 5}, } else: config.database_config = { @@ -107,6 +169,21 @@ def setup_test_homeserver( db_engine = create_engine(config.database_config) + # Create the database before we actually try and connect to it, based off + # the template database we generate in setupdb() + if datastore is None and isinstance(db_engine, PostgresEngine): + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + cur.execute( + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) + ) + cur.close() + db_conn.close() + # we need to configure the connection pool to run the on_new_connection # function, so that we can test code that uses custom sqlite functions # (like rank). @@ -125,15 +202,35 @@ def setup_test_homeserver( reactor=reactor, **kargs ) - db_conn = hs.get_db_conn() - # make sure that the database is empty - if isinstance(db_engine, PostgresEngine): - cur = db_conn.cursor() - cur.execute("SELECT tablename FROM pg_tables where schemaname='public'") - rows = cur.fetchall() - for r in rows: - cur.execute("DROP TABLE %s CASCADE" % r[0]) - yield prepare_database(db_conn, db_engine, config) + + # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to + # date db + if not isinstance(db_engine, PostgresEngine): + db_conn = hs.get_db_conn() + yield prepare_database(db_conn, db_engine, config) + db_conn.commit() + db_conn.close() + + else: + # We need to do cleanup on PostgreSQL + def cleanup(): + # Close all the db pools + hs.get_db_pool().close() + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + cur.close() + db_conn.close() + + # Register the cleanup hook + cleanup_func(cleanup) + hs.setup() else: hs = HomeServer( |