summary refs log tree commit diff
path: root/synapse/storage/engines
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/engines')
-rw-r--r--synapse/storage/engines/_base.py23
-rw-r--r--synapse/storage/engines/postgres.py12
-rw-r--r--synapse/storage/engines/sqlite.py21
3 files changed, 52 insertions, 4 deletions
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 0d16a419a4..70e594a68f 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -32,9 +32,10 @@ class IncorrectDatabaseSetup(RuntimeError):
 
 
 ConnectionType = TypeVar("ConnectionType", bound=Connection)
+CursorType = TypeVar("CursorType", bound=Cursor)
 
 
-class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
+class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCMeta):
     def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]):
         self.module = module
 
@@ -64,7 +65,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def check_new_database(self, txn: Cursor) -> None:
+    def check_new_database(self, txn: CursorType) -> None:
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
@@ -124,3 +125,21 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         Note: This has no effect on SQLite3, as transactions are SERIALIZABLE by default.
         """
         ...
+
+    @staticmethod
+    @abc.abstractmethod
+    def executescript(cursor: CursorType, script: str) -> None:
+        """Execute a chunk of SQL containing multiple semicolon-delimited statements.
+
+        This is not provided by DBAPI2, and so needs engine-specific support.
+        """
+        ...
+
+    @classmethod
+    def execute_script_file(cls, cursor: CursorType, filepath: str) -> None:
+        """Execute a file containing multiple semicolon-delimited SQL statements.
+
+        This is not provided by DBAPI2, and so needs engine-specific support.
+        """
+        with open(filepath, "rt") as f:
+            cls.executescript(cursor, f.read())
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 7f7d006ac2..d8c0f64d9a 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -31,7 +31,9 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class PostgresEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
+class PostgresEngine(
+    BaseDatabaseEngine[psycopg2.extensions.connection, psycopg2.extensions.cursor]
+):
     def __init__(self, database_config: Mapping[str, Any]):
         super().__init__(psycopg2, database_config)
         psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
@@ -212,3 +214,11 @@ class PostgresEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
         else:
             isolation_level = self.isolation_level_map[isolation_level]
         return conn.set_isolation_level(isolation_level)
+
+    @staticmethod
+    def executescript(cursor: psycopg2.extensions.cursor, script: str) -> None:
+        """Execute a chunk of SQL containing multiple semicolon-delimited statements.
+
+        Psycopg2 seems happy to do this in DBAPI2's `execute()` function.
+        """
+        cursor.execute(script)
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 095ae0a096..faa574dbfd 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from synapse.storage.database import LoggingDatabaseConnection
 
 
-class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
+class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
     def __init__(self, database_config: Mapping[str, Any]):
         super().__init__(sqlite3, database_config)
 
@@ -120,6 +120,25 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
         # All transactions are SERIALIZABLE by default in sqlite
         pass
 
+    @staticmethod
+    def executescript(cursor: sqlite3.Cursor, script: str) -> None:
+        """Execute a chunk of SQL containing multiple semicolon-delimited statements.
+
+        Python's built-in SQLite driver does not allow you to do this with DBAPI2's
+        `execute`:
+
+        > execute() will only execute a single SQL statement. If you try to execute more
+        > than one statement with it, it will raise a Warning. Use executescript() if
+        > you want to execute multiple SQL statements with one call.
+
+        Though the docs for `executescript` warn:
+
+        > If there is a pending transaction, an implicit COMMIT statement is executed
+        > first. No other implicit transaction control is performed; any transaction
+        > control must be added to sql_script.
+        """
+        cursor.executescript(script)
+
 
 # Following functions taken from: https://github.com/coleifer/peewee