summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xsynapse/app/homeserver.py36
1 files changed, 19 insertions, 17 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index f210d26629..d63afd1b4a 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -43,6 +43,17 @@ import re
 logger = logging.getLogger(__name__)
 
 
+SCHEMAS = [
+    "transactions",
+    "pdu",
+    "users",
+    "profiles",
+    "presence",
+    "im",
+    "room_aliases",
+]
+
+
 class SynapseHomeServer(HomeServer):
 
     def build_http_client(self):
@@ -65,24 +76,11 @@ class SynapseHomeServer(HomeServer):
         don't have to worry about overwriting existing content.
         """
         logging.info("Preparing database: %s...", self.db_name)
-        pool = adbapi.ConnectionPool(
-            'sqlite3', self.db_name, check_same_thread=False,
-            cp_min=1, cp_max=1)
 
-        schemas = [
-            "transactions",
-            "pdu",
-            "users",
-            "profiles",
-            "presence",
-            "im",
-            "room_aliases",
-        ]
+        with sqlite3.connect(self.db_name) as db_conn:
+            for sql_loc in SCHEMAS:
+                sql_script = read_schema(sql_loc)
 
-        for sql_loc in schemas:
-            sql_script = read_schema(sql_loc)
-
-            with sqlite3.connect(self.db_name) as db_conn:
                 c = db_conn.cursor()
                 c.executescript(sql_script)
                 c.close()
@@ -90,6 +88,10 @@ class SynapseHomeServer(HomeServer):
 
         logging.info("Database prepared in %s.", self.db_name)
 
+        pool = adbapi.ConnectionPool(
+            'sqlite3', self.db_name, check_same_thread=False,
+            cp_min=1, cp_max=1)
+
         return pool
 
     def create_resource_tree(self, web_client, redirect_root_to_web_client):
@@ -282,7 +284,7 @@ def setup():
         redirect_root_to_web_client=True)
     hs.start_listening(args.port)
 
-    hs.build_db_pool()
+    hs.get_db_pool()
 
     if args.manhole:
         f = twisted.manhole.telnet.ShellFactory()