summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py57
1 files changed, 26 insertions, 31 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 2543fb12b7..6b273a0306 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -43,7 +43,6 @@ from .keys import KeyStore
 import json
 import logging
 import os
-import sqlite3
 
 
 logger = logging.getLogger(__name__)
@@ -370,44 +369,40 @@ def read_schema(schema):
         return schema_file.read()
 
 
-def prepare_database(db_name):
+def prepare_database(db_conn):
     """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
     don't have to worry about overwriting existing content.
     """
-    logging.info("Preparing database: %s...", db_name)
+    c = db_conn.cursor()
+    c.execute("PRAGMA user_version")
+    row = c.fetchone()
 
-    with sqlite3.connect(db_name) as db_conn:
-        c = db_conn.cursor()
-        c.execute("PRAGMA user_version")
-        row = c.fetchone()
+    if row and row[0]:
+        user_version = row[0]
 
-        if row and row[0]:
-            user_version = row[0]
-
-            if user_version > SCHEMA_VERSION:
-                raise ValueError("Cannot use this database as it is too " +
-                    "new for the server to understand"
-                )
-            elif user_version < SCHEMA_VERSION:
-                logging.info("Upgrading database from version %d",
-                    user_version
-                )
+        if user_version > SCHEMA_VERSION:
+            raise ValueError("Cannot use this database as it is too " +
+                "new for the server to understand"
+            )
+        elif user_version < SCHEMA_VERSION:
+            logging.info("Upgrading database from version %d",
+                user_version
+            )
 
-                # Run every version since after the current version.
-                for v in range(user_version + 1, SCHEMA_VERSION + 1):
-                    sql_script = read_schema("delta/v%d" % (v))
-                    c.executescript(sql_script)
+            # Run every version since after the current version.
+            for v in range(user_version + 1, SCHEMA_VERSION + 1):
+                sql_script = read_schema("delta/v%d" % (v))
+                c.executescript(sql_script)
 
-                db_conn.commit()
+            db_conn.commit()
 
-        else:
-            for sql_loc in SCHEMAS:
-                sql_script = read_schema(sql_loc)
+    else:
+        for sql_loc in SCHEMAS:
+            sql_script = read_schema(sql_loc)
 
-                c.executescript(sql_script)
-            db_conn.commit()
-            c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
+            c.executescript(sql_script)
+        db_conn.commit()
+        c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
 
-        c.close()
+    c.close()
 
-    logging.info("Database prepared in %s.", db_name)