summary refs log tree commit diff
diff options
context:
space:
mode:
authorPaul "LeoNerd" Evans <paul@matrix.org>2014-09-10 16:23:58 +0100
committerPaul "LeoNerd" Evans <paul@matrix.org>2014-09-10 16:23:58 +0100
commit55397f634770f2b91cd4567e6b40507944144b67 (patch)
treea056b26298b5766cc4459c20e1bbc7c8150759b6
parentMake sure not to open our TCP ports until /after/ the DB is nicely prepared r... (diff)
downloadsynapse-55397f634770f2b91cd4567e6b40507944144b67.tar.xz
prepare_database() on db_conn, not plain name, so we can pass in the connection from outside
-rwxr-xr-xsynapse/app/homeserver.py10
-rw-r--r--synapse/storage/__init__.py57
2 files changed, 35 insertions, 32 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e6377e3060..2f1b954902 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -39,6 +39,7 @@ import logging
 import os
 import re
 import sys
+import sqlite3
 
 logger = logging.getLogger(__name__)
 
@@ -208,7 +209,14 @@ def setup():
         redirect_root_to_web_client=True,
     )
 
-    prepare_database(hs.get_db_name())
+    db_name = hs.get_db_name()
+
+    logging.info("Preparing database: %s...", db_name)
+
+    with sqlite3.connect(db_name) as db_conn:
+        prepare_database(db_conn)
+
+    logging.info("Database prepared in %s.", db_name)
 
     hs.get_db_pool()
 
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 2543fb12b7..6b273a0306 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -43,7 +43,6 @@ from .keys import KeyStore
 import json
 import logging
 import os
-import sqlite3
 
 
 logger = logging.getLogger(__name__)
@@ -370,44 +369,40 @@ def read_schema(schema):
         return schema_file.read()
 
 
-def prepare_database(db_name):
+def prepare_database(db_conn):
     """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
     don't have to worry about overwriting existing content.
     """
-    logging.info("Preparing database: %s...", db_name)
+    c = db_conn.cursor()
+    c.execute("PRAGMA user_version")
+    row = c.fetchone()
 
-    with sqlite3.connect(db_name) as db_conn:
-        c = db_conn.cursor()
-        c.execute("PRAGMA user_version")
-        row = c.fetchone()
+    if row and row[0]:
+        user_version = row[0]
 
-        if row and row[0]:
-            user_version = row[0]
-
-            if user_version > SCHEMA_VERSION:
-                raise ValueError("Cannot use this database as it is too " +
-                    "new for the server to understand"
-                )
-            elif user_version < SCHEMA_VERSION:
-                logging.info("Upgrading database from version %d",
-                    user_version
-                )
+        if user_version > SCHEMA_VERSION:
+            raise ValueError("Cannot use this database as it is too " +
+                "new for the server to understand"
+            )
+        elif user_version < SCHEMA_VERSION:
+            logging.info("Upgrading database from version %d",
+                user_version
+            )
 
-                # Run every version since after the current version.
-                for v in range(user_version + 1, SCHEMA_VERSION + 1):
-                    sql_script = read_schema("delta/v%d" % (v))
-                    c.executescript(sql_script)
+            # Run every version since after the current version.
+            for v in range(user_version + 1, SCHEMA_VERSION + 1):
+                sql_script = read_schema("delta/v%d" % (v))
+                c.executescript(sql_script)
 
-                db_conn.commit()
+            db_conn.commit()
 
-        else:
-            for sql_loc in SCHEMAS:
-                sql_script = read_schema(sql_loc)
+    else:
+        for sql_loc in SCHEMAS:
+            sql_script = read_schema(sql_loc)
 
-                c.executescript(sql_script)
-            db_conn.commit()
-            c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
+            c.executescript(sql_script)
+        db_conn.commit()
+        c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
 
-        c.close()
+    c.close()
 
-    logging.info("Database prepared in %s.", db_name)