prepare_database() on db_conn, not plain name, so we can pass in the connection from outside
2 files changed, 35 insertions, 32 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e6377e3060..2f1b954902 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -39,6 +39,7 @@ import logging
import os
import re
import sys
+import sqlite3
logger = logging.getLogger(__name__)
@@ -208,7 +209,14 @@ def setup():
redirect_root_to_web_client=True,
)
- prepare_database(hs.get_db_name())
+ db_name = hs.get_db_name()
+
+ logging.info("Preparing database: %s...", db_name)
+
+ with sqlite3.connect(db_name) as db_conn:
+ prepare_database(db_conn)
+
+ logging.info("Database prepared in %s.", db_name)
hs.get_db_pool()
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 2543fb12b7..6b273a0306 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -43,7 +43,6 @@ from .keys import KeyStore
import json
import logging
import os
-import sqlite3
logger = logging.getLogger(__name__)
@@ -370,44 +369,40 @@ def read_schema(schema):
return schema_file.read()
-def prepare_database(db_name):
+def prepare_database(db_conn):
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
don't have to worry about overwriting existing content.
"""
- logging.info("Preparing database: %s...", db_name)
+ c = db_conn.cursor()
+ c.execute("PRAGMA user_version")
+ row = c.fetchone()
- with sqlite3.connect(db_name) as db_conn:
- c = db_conn.cursor()
- c.execute("PRAGMA user_version")
- row = c.fetchone()
+ if row and row[0]:
+ user_version = row[0]
- if row and row[0]:
- user_version = row[0]
-
- if user_version > SCHEMA_VERSION:
- raise ValueError("Cannot use this database as it is too " +
- "new for the server to understand"
- )
- elif user_version < SCHEMA_VERSION:
- logging.info("Upgrading database from version %d",
- user_version
- )
+ if user_version > SCHEMA_VERSION:
+ raise ValueError("Cannot use this database as it is too " +
+ "new for the server to understand"
+ )
+ elif user_version < SCHEMA_VERSION:
+ logging.info("Upgrading database from version %d",
+ user_version
+ )
- # Run every version since after the current version.
- for v in range(user_version + 1, SCHEMA_VERSION + 1):
- sql_script = read_schema("delta/v%d" % (v))
- c.executescript(sql_script)
+ # Run every version since after the current version.
+ for v in range(user_version + 1, SCHEMA_VERSION + 1):
+ sql_script = read_schema("delta/v%d" % (v))
+ c.executescript(sql_script)
- db_conn.commit()
+ db_conn.commit()
- else:
- for sql_loc in SCHEMAS:
- sql_script = read_schema(sql_loc)
+ else:
+ for sql_loc in SCHEMAS:
+ sql_script = read_schema(sql_loc)
- c.executescript(sql_script)
- db_conn.commit()
- c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
+ c.executescript(sql_script)
+ db_conn.commit()
+ c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
- c.close()
+ c.close()
- logging.info("Database prepared in %s.", db_name)
|