summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/server.py3
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/background_updates.py2
-rw-r--r--synapse/storage/data_stores/__init__.py7
-rw-r--r--synapse/storage/database.py38
5 files changed, 24 insertions, 28 deletions
diff --git a/synapse/server.py b/synapse/server.py
index be9af7f986..2db3dab221 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -238,8 +238,7 @@ class HomeServer(object):
     def setup(self):
         logger.info("Setting up.")
         with self.get_db_conn() as conn:
-            datastore = self.DATASTORE_CLASS(conn, self)
-            self.datastores = DataStores(datastore, conn, self)
+            self.datastores = DataStores(self.DATASTORE_CLASS, conn, self)
             conn.commit()
         self.start_time = int(self.get_clock().time())
         logger.info("Finished setting up.")
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f9e7f9a71e..b7637b5dc0 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -41,7 +41,7 @@ class SQLBaseStore(object):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = hs.database_engine
-        self.db = Database(hs)  # In future this will be passed in
+        self.db = database
         self.rand = random.SystemRandom()
 
     def _invalidate_state_caches(self, room_id, members_changed):
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index a9a13a2658..4f97fd5ab6 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -379,7 +379,7 @@ class BackgroundUpdater(object):
             logger.debug("[SQL] %s", sql)
             c.execute(sql)
 
-        if isinstance(self.db.database_engine, engines.PostgresEngine):
+        if isinstance(self.db.engine, engines.PostgresEngine):
             runner = create_index_psql
         elif psql_only:
             runner = None
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cb184a98cc..79ecc62735 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.storage.database import Database
+
 
 class DataStores(object):
     """The various data stores.
@@ -20,7 +22,8 @@ class DataStores(object):
     These are low level interfaces to physical databases.
     """
 
-    def __init__(self, main_store, db_conn, hs):
+    def __init__(self, main_store_class, db_conn, hs):
         # Note we pass in the main store here as workers use a different main
         # store.
-        self.main = main_store
+        database = Database(hs)
+        self.main = main_store_class(database, db_conn, hs)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6843b7e7f8..ec19ae1d9d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -234,7 +234,7 @@ class Database(object):
         #   to watch it
         self._txn_perf_counters = PerformanceCounters()
 
-        self.database_engine = hs.database_engine
+        self.engine = hs.database_engine
 
         # A set of tables that are not safe to use native upserts in.
         self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@@ -242,10 +242,10 @@ class Database(object):
         # We add the user_directory_search table to the blacklist on SQLite
         # because the existing search table does not have an index, making it
         # unsafe to use native upserts.
-        if isinstance(self.database_engine, Sqlite3Engine):
+        if isinstance(self.engine, Sqlite3Engine):
             self._unsafe_to_upsert_tables.add("user_directory_search")
 
-        if self.database_engine.can_native_upsert:
+        if self.engine.can_native_upsert:
             # Check ASAP (and then later, every 1s) to see if we have finished
             # background updates of tables that aren't safe to update.
             self._clock.call_later(
@@ -331,7 +331,7 @@ class Database(object):
                 cursor = LoggingTransaction(
                     conn.cursor(),
                     name,
-                    self.database_engine,
+                    self.engine,
                     after_callbacks,
                     exception_callbacks,
                 )
@@ -339,7 +339,7 @@ class Database(object):
                     r = func(cursor, *args, **kwargs)
                     conn.commit()
                     return r
-                except self.database_engine.module.OperationalError as e:
+                except self.engine.module.OperationalError as e:
                     # This can happen if the database disappears mid
                     # transaction.
                     logger.warning(
@@ -353,20 +353,20 @@ class Database(object):
                         i += 1
                         try:
                             conn.rollback()
-                        except self.database_engine.module.Error as e1:
+                        except self.engine.module.Error as e1:
                             logger.warning(
                                 "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
                             )
                         continue
                     raise
-                except self.database_engine.module.DatabaseError as e:
-                    if self.database_engine.is_deadlock(e):
+                except self.engine.module.DatabaseError as e:
+                    if self.engine.is_deadlock(e):
                         logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
                         if i < N:
                             i += 1
                             try:
                                 conn.rollback()
-                            except self.database_engine.module.Error as e1:
+                            except self.engine.module.Error as e1:
                                 logger.warning(
                                     "[TXN EROLL] {%s} %s",
                                     name,
@@ -494,7 +494,7 @@ class Database(object):
                 sql_scheduling_timer.observe(sched_duration_sec)
                 context.add_database_scheduled(sched_duration_sec)
 
-                if self.database_engine.is_connection_closed(conn):
+                if self.engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
@@ -561,7 +561,7 @@ class Database(object):
         """
         try:
             yield self.runInteraction(desc, self.simple_insert_txn, table, values)
-        except self.database_engine.module.IntegrityError:
+        except self.engine.module.IntegrityError:
             # We have to do or_ignore flag at this layer, since we can't reuse
             # a cursor after we receive an error from the db.
             if not or_ignore:
@@ -660,7 +660,7 @@ class Database(object):
                     lock=lock,
                 )
                 return result
-            except self.database_engine.module.IntegrityError as e:
+            except self.engine.module.IntegrityError as e:
                 attempts += 1
                 if attempts >= 5:
                     # don't retry forever, because things other than races
@@ -692,10 +692,7 @@ class Database(object):
             upserts return True if a new entry was created, False if an existing
             one was updated.
         """
-        if (
-            self.database_engine.can_native_upsert
-            and table not in self._unsafe_to_upsert_tables
-        ):
+        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
             )
@@ -726,7 +723,7 @@ class Database(object):
         """
         # We need to lock the table :(, unless we're *really* careful
         if lock:
-            self.database_engine.lock_table(txn, table)
+            self.engine.lock_table(txn, table)
 
         def _getwhere(key):
             # If the value we're passing in is None (aka NULL), we need to use
@@ -828,10 +825,7 @@ class Database(object):
         Returns:
             None
         """
-        if (
-            self.database_engine.can_native_upsert
-            and table not in self._unsafe_to_upsert_tables
-        ):
+        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_many_txn_native_upsert(
                 txn, table, key_names, key_values, value_names, value_values
             )
@@ -1301,7 +1295,7 @@ class Database(object):
             "limit": limit,
         }
 
-        sql = self.database_engine.convert_param_style(sql)
+        sql = self.engine.convert_param_style(sql)
 
         txn = db_conn.cursor()
         txn.execute(sql, (int(max_value),))