summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xsynapse/app/homeserver.py73
-rw-r--r--synapse/server.py1
-rw-r--r--synapse/storage/__init__.py61
3 files changed, 71 insertions, 64 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index d675d8c8f9..b63ecd4b5f 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage import read_schema
+from synapse.storage import prepare_database
 
 from synapse.server import HomeServer
 
@@ -36,7 +36,6 @@ from daemonize import Daemonize
 import twisted.manhole.telnet
 
 import logging
-import sqlite3
 import os
 import re
 import sys
@@ -44,22 +43,6 @@ import sys
 logger = logging.getLogger(__name__)
 
 
-SCHEMAS = [
-    "transactions",
-    "pdu",
-    "users",
-    "profiles",
-    "presence",
-    "im",
-    "room_aliases",
-]
-
-
-# Remember to update this number every time an incompatible change is made to
-# database schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 3
-
-
 class SynapseHomeServer(HomeServer):
 
     def build_http_client(self):
@@ -80,52 +63,12 @@ class SynapseHomeServer(HomeServer):
         )
 
     def build_db_pool(self):
-        """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
-        don't have to worry about overwriting existing content.
-        """
-        logging.info("Preparing database: %s...", self.db_name)
-
-        with sqlite3.connect(self.db_name) as db_conn:
-            c = db_conn.cursor()
-            c.execute("PRAGMA user_version")
-            row = c.fetchone()
-
-            if row and row[0]:
-                user_version = row[0]
-
-                if user_version > SCHEMA_VERSION:
-                    raise ValueError("Cannot use this database as it is too " +
-                        "new for the server to understand"
-                    )
-                elif user_version < SCHEMA_VERSION:
-                    logging.info("Upgrading database from version %d",
-                        user_version
-                    )
-
-                    # Run every version since after the current version.
-                    for v in range(user_version + 1, SCHEMA_VERSION + 1):
-                        sql_script = read_schema("delta/v%d" % (v))
-                        c.executescript(sql_script)
-
-                    db_conn.commit()
-
-            else:
-                for sql_loc in SCHEMAS:
-                    sql_script = read_schema(sql_loc)
-
-                    c.executescript(sql_script)
-                db_conn.commit()
-                c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
-
-            c.close()
-
-        logging.info("Database prepared in %s.", self.db_name)
-
-        pool = adbapi.ConnectionPool(
-            'sqlite3', self.db_name, check_same_thread=False,
-            cp_min=1, cp_max=1)
-
-        return pool
+        return adbapi.ConnectionPool(
+            "sqlite3", self.get_db_name(),
+            check_same_thread=False,
+            cp_min=1,
+            cp_max=1
+        )
 
     def create_resource_tree(self, web_client, redirect_root_to_web_client):
         """Create the resource tree for this Home Server.
@@ -270,6 +213,8 @@ def setup():
     )
     hs.start_listening(config.bind_port, config.unsecure_port)
 
+    prepare_database(hs.get_db_name())
+
     hs.get_db_pool()
 
     if config.manhole:
diff --git a/synapse/server.py b/synapse/server.py
index 83368ea5a7..1ba13f3df2 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -57,6 +57,7 @@ class BaseHomeServer(object):
     DEPENDENCIES = [
         'clock',
         'http_client',
+        'db_name',
         'db_pool',
         'persistence_service',
         'replication_layer',
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index ad2a484c16..2543fb12b7 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -43,10 +43,28 @@ from .keys import KeyStore
 import json
 import logging
 import os
+import sqlite3
 
 
 logger = logging.getLogger(__name__)
 
+
+SCHEMAS = [
+    "transactions",
+    "pdu",
+    "users",
+    "profiles",
+    "presence",
+    "im",
+    "room_aliases",
+]
+
+
+# Remember to update this number every time an incompatible change is made to
+# database schema files, so the users will be informed on server restarts.
+SCHEMA_VERSION = 3
+
+
 class _RollbackButIsFineException(Exception):
     """ This exception is used to rollback a transaction without implying
     something went wrong.
@@ -350,3 +368,46 @@ def read_schema(schema):
     """
     with open(schema_path(schema)) as schema_file:
         return schema_file.read()
+
+
+def prepare_database(db_name):
+    """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
+    don't have to worry about overwriting existing content.
+    """
+    logging.info("Preparing database: %s...", db_name)
+
+    with sqlite3.connect(db_name) as db_conn:
+        c = db_conn.cursor()
+        c.execute("PRAGMA user_version")
+        row = c.fetchone()
+
+        if row and row[0]:
+            user_version = row[0]
+
+            if user_version > SCHEMA_VERSION:
+                raise ValueError("Cannot use this database as it is too " +
+                    "new for the server to understand"
+                )
+            elif user_version < SCHEMA_VERSION:
+                logging.info("Upgrading database from version %d",
+                    user_version
+                )
+
+                # Run every version since after the current version.
+                for v in range(user_version + 1, SCHEMA_VERSION + 1):
+                    sql_script = read_schema("delta/v%d" % (v))
+                    c.executescript(sql_script)
+
+                db_conn.commit()
+
+        else:
+            for sql_loc in SCHEMAS:
+                sql_script = read_schema(sql_loc)
+
+                c.executescript(sql_script)
+            db_conn.commit()
+            c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
+
+        c.close()
+
+    logging.info("Database prepared in %s.", db_name)