diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index d675d8c8f9..b63ecd4b5f 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage import read_schema
+from synapse.storage import prepare_database
from synapse.server import HomeServer
@@ -36,7 +36,6 @@ from daemonize import Daemonize
import twisted.manhole.telnet
import logging
-import sqlite3
import os
import re
import sys
@@ -44,22 +43,6 @@ import sys
logger = logging.getLogger(__name__)
-SCHEMAS = [
- "transactions",
- "pdu",
- "users",
- "profiles",
- "presence",
- "im",
- "room_aliases",
-]
-
-
-# Remember to update this number every time an incompatible change is made to
-# database schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 3
-
-
class SynapseHomeServer(HomeServer):
def build_http_client(self):
@@ -80,52 +63,12 @@ class SynapseHomeServer(HomeServer):
)
def build_db_pool(self):
- """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
- don't have to worry about overwriting existing content.
- """
- logging.info("Preparing database: %s...", self.db_name)
-
- with sqlite3.connect(self.db_name) as db_conn:
- c = db_conn.cursor()
- c.execute("PRAGMA user_version")
- row = c.fetchone()
-
- if row and row[0]:
- user_version = row[0]
-
- if user_version > SCHEMA_VERSION:
- raise ValueError("Cannot use this database as it is too " +
- "new for the server to understand"
- )
- elif user_version < SCHEMA_VERSION:
- logging.info("Upgrading database from version %d",
- user_version
- )
-
- # Run every version since after the current version.
- for v in range(user_version + 1, SCHEMA_VERSION + 1):
- sql_script = read_schema("delta/v%d" % (v))
- c.executescript(sql_script)
-
- db_conn.commit()
-
- else:
- for sql_loc in SCHEMAS:
- sql_script = read_schema(sql_loc)
-
- c.executescript(sql_script)
- db_conn.commit()
- c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
-
- c.close()
-
- logging.info("Database prepared in %s.", self.db_name)
-
- pool = adbapi.ConnectionPool(
- 'sqlite3', self.db_name, check_same_thread=False,
- cp_min=1, cp_max=1)
-
- return pool
+ return adbapi.ConnectionPool(
+ "sqlite3", self.get_db_name(),
+ check_same_thread=False,
+ cp_min=1,
+ cp_max=1
+ )
def create_resource_tree(self, web_client, redirect_root_to_web_client):
"""Create the resource tree for this Home Server.
@@ -270,6 +213,8 @@ def setup():
)
hs.start_listening(config.bind_port, config.unsecure_port)
+ prepare_database(hs.get_db_name())
+
hs.get_db_pool()
if config.manhole:
diff --git a/synapse/server.py b/synapse/server.py
index 83368ea5a7..1ba13f3df2 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -57,6 +57,7 @@ class BaseHomeServer(object):
DEPENDENCIES = [
'clock',
'http_client',
+ 'db_name',
'db_pool',
'persistence_service',
'replication_layer',
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index ad2a484c16..2543fb12b7 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -43,10 +43,28 @@ from .keys import KeyStore
import json
import logging
import os
+import sqlite3
logger = logging.getLogger(__name__)
+
+SCHEMAS = [
+ "transactions",
+ "pdu",
+ "users",
+ "profiles",
+ "presence",
+ "im",
+ "room_aliases",
+]
+
+
+# Remember to update this number every time an incompatible change is made to
+# database schema files, so the users will be informed on server restarts.
+SCHEMA_VERSION = 3
+
+
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
something went wrong.
@@ -350,3 +368,46 @@ def read_schema(schema):
"""
with open(schema_path(schema)) as schema_file:
return schema_file.read()
+
+
+def prepare_database(db_name):
+ """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
+ don't have to worry about overwriting existing content.
+ """
+ logging.info("Preparing database: %s...", db_name)
+
+ with sqlite3.connect(db_name) as db_conn:
+ c = db_conn.cursor()
+ c.execute("PRAGMA user_version")
+ row = c.fetchone()
+
+ if row and row[0]:
+ user_version = row[0]
+
+ if user_version > SCHEMA_VERSION:
+ raise ValueError("Cannot use this database as it is too " +
+ "new for the server to understand"
+ )
+ elif user_version < SCHEMA_VERSION:
+ logging.info("Upgrading database from version %d",
+ user_version
+ )
+
+ # Run every version since after the current version.
+ for v in range(user_version + 1, SCHEMA_VERSION + 1):
+ sql_script = read_schema("delta/v%d" % (v))
+ c.executescript(sql_script)
+
+ db_conn.commit()
+
+ else:
+ for sql_loc in SCHEMAS:
+ sql_script = read_schema(sql_loc)
+
+ c.executescript(sql_script)
+ db_conn.commit()
+ c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
+
+ c.close()
+
+ logging.info("Database prepared in %s.", db_name)
|