diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 709b6f88ac..29702be923 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -16,6 +16,8 @@
from .maria import MariaEngine
from .sqlite3 import Sqlite3Engine
+import importlib
+
SUPPORTED_MODULE = {
"sqlite3": Sqlite3Engine,
@@ -27,7 +29,7 @@ def create_engine(name):
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
- module = __import__(name)
+ module = importlib.import_module(name)
return engine_class(module)
raise RuntimeError(
diff --git a/synapse/storage/engines/maria.py b/synapse/storage/engines/maria.py
index df47763647..7fcb706a60 100644
--- a/synapse/storage/engines/maria.py
+++ b/synapse/storage/engines/maria.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage import prepare_database
import types
@@ -28,3 +29,14 @@ class MariaEngine(object):
if isinstance(param, types.BufferType):
return str(param)
return param
+
+ def on_new_connection(self, db_conn):
+ pass
+
+ def prepare_database(self, db_conn):
+ cur = db_conn.cursor()
+ cur.execute(
+ "ALTER DATABASE CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci"
+ )
+ db_conn.commit()
+ prepare_database(db_conn, self)
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 639cdea41d..e802b5d5fd 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage import prepare_database, prepare_sqlite3_database
+
class Sqlite3Engine(object):
def __init__(self, database_module):
@@ -23,3 +25,10 @@ class Sqlite3Engine(object):
def encode_parameter(self, param):
return param
+
+ def on_new_connection(self, db_conn):
+ self.prepare_database(db_conn)
+
+ def prepare_database(self, db_conn):
+ prepare_sqlite3_database(db_conn)
+ prepare_database(db_conn, self)
|