diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 971ff82693..70e594a68f 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -32,9 +32,10 @@ class IncorrectDatabaseSetup(RuntimeError):
ConnectionType = TypeVar("ConnectionType", bound=Connection)
+CursorType = TypeVar("CursorType", bound=Cursor)
-class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
+class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCMeta):
def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]):
self.module = module
@@ -45,14 +46,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
- def can_native_upsert(self) -> bool:
- """
- Do we support native UPSERTs?
- """
- ...
-
- @property
- @abc.abstractmethod
def supports_using_any_list(self) -> bool:
"""
Do we support using `a = ANY(?)` and passing a list
@@ -72,7 +65,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
...
@abc.abstractmethod
- def check_new_database(self, txn: Cursor) -> None:
+ def check_new_database(self, txn: CursorType) -> None:
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
@@ -132,3 +125,21 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
Note: This has no effect on SQLite3, as transactions are SERIALIZABLE by default.
"""
...
+
+ @staticmethod
+ @abc.abstractmethod
+ def executescript(cursor: CursorType, script: str) -> None:
+ """Execute a chunk of SQL containing multiple semicolon-delimited statements.
+
+ This is not provided by DBAPI2, and so needs engine-specific support.
+ """
+ ...
+
+ @classmethod
+ def execute_script_file(cls, cursor: CursorType, filepath: str) -> None:
+ """Execute a file containing multiple semicolon-delimited SQL statements.
+
+ This is not provided by DBAPI2, and so needs engine-specific support.
+ """
+ with open(filepath, "rt") as f:
+ cls.executescript(cursor, f.read())
|