diff options
Diffstat (limited to 'synapse/storage/engines/sqlite.py')
-rw-r--r-- | synapse/storage/engines/sqlite.py | 72 |
1 files changed, 36 insertions, 36 deletions
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 6c19e55999..621f2c5efe 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. import platform +import sqlite3 import struct import threading -import typing -from typing import Optional +from typing import TYPE_CHECKING, Any, List, Mapping, Optional from synapse.storage.engines import BaseDatabaseEngine -from synapse.storage.types import Connection +from synapse.storage.types import Cursor -if typing.TYPE_CHECKING: - import sqlite3 # noqa: F401 +if TYPE_CHECKING: + from synapse.storage.database import LoggingDatabaseConnection -class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): - def __init__(self, database_module, database_config): - super().__init__(database_module, database_config) +class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): + def __init__(self, database_config: Mapping[str, Any]): + super().__init__(sqlite3, database_config) database = database_config.get("args", {}).get("database") self._is_in_memory = database in ( @@ -37,7 +37,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): if platform.python_implementation() == "PyPy": # pypy's sqlite3 module doesn't handle bytearrays, convert them # back to bytes. - database_module.register_adapter(bytearray, lambda array: bytes(array)) + sqlite3.register_adapter(bytearray, lambda array: bytes(array)) # The current max state_group, or None if we haven't looked # in the DB yet. @@ -49,41 +49,43 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): return True @property - def can_native_upsert(self): + def can_native_upsert(self) -> bool: """ Do we support native UPSERTs? This requires SQLite3 3.24+, plus some more work we haven't done yet to tell what was inserted vs updated. """ - return self.module.sqlite_version_info >= (3, 24, 0) + return sqlite3.sqlite_version_info >= (3, 24, 0) @property - def supports_using_any_list(self): + def supports_using_any_list(self) -> bool: """Do we support using `a = ANY(?)` and passing a list""" return False @property def supports_returning(self) -> bool: """Do we support the `RETURNING` clause in insert/update/delete?""" - return self.module.sqlite_version_info >= (3, 35, 0) + return sqlite3.sqlite_version_info >= (3, 35, 0) - def check_database(self, db_conn, allow_outdated_version: bool = False): + def check_database( + self, db_conn: sqlite3.Connection, allow_outdated_version: bool = False + ) -> None: if not allow_outdated_version: - version = self.module.sqlite_version_info + version = sqlite3.sqlite_version_info # Synapse is untested against older SQLite versions, and we don't want # to let users upgrade to a version of Synapse with broken support for their # sqlite version, because it risks leaving them with a half-upgraded db. if version < (3, 22, 0): raise RuntimeError("Synapse requires sqlite 3.22 or above.") - def check_new_database(self, txn): + def check_new_database(self, txn: Cursor) -> None: """Gets called when setting up a brand new database. This allows us to apply stricter checks on new databases versus existing database. """ - def convert_param_style(self, sql): + def convert_param_style(self, sql: str) -> str: return sql - def on_new_connection(self, db_conn): + def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: # We need to import here to avoid an import loop. from synapse.storage.prepare_database import prepare_database @@ -97,48 +99,46 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): db_conn.execute("PRAGMA foreign_keys = ON;") db_conn.commit() - def is_deadlock(self, error): + def is_deadlock(self, error: Exception) -> bool: return False - def is_connection_closed(self, conn): + def is_connection_closed(self, conn: sqlite3.Connection) -> bool: return False - def lock_table(self, txn, table): + def lock_table(self, txn: Cursor, table: str) -> None: return @property - def server_version(self): - """Gets a string giving the server version. For example: '3.22.0' + def server_version(self) -> str: + """Gets a string giving the server version. For example: '3.22.0'.""" + return "%i.%i.%i" % sqlite3.sqlite_version_info - Returns: - string - """ - return "%i.%i.%i" % self.module.sqlite_version_info - - def in_transaction(self, conn: Connection) -> bool: - return conn.in_transaction # type: ignore + def in_transaction(self, conn: sqlite3.Connection) -> bool: + return conn.in_transaction - def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + def attempt_to_set_autocommit( + self, conn: sqlite3.Connection, autocommit: bool + ) -> None: # Twisted doesn't let us set attributes on the connections, so we can't # set the connection to autocommit mode. pass def attempt_to_set_isolation_level( - self, conn: Connection, isolation_level: Optional[int] - ): - # All transactions are SERIALIZABLE by default in sqllite + self, conn: sqlite3.Connection, isolation_level: Optional[int] + ) -> None: + # All transactions are SERIALIZABLE by default in sqlite pass # Following functions taken from: https://github.com/coleifer/peewee -def _parse_match_info(buf): +def _parse_match_info(buf: bytes) -> List[int]: bufsize = len(buf) return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)] -def _rank(raw_match_info): +def _rank(raw_match_info: bytes) -> float: """Handle match_info called w/default args 'pcx' - based on the example rank function http://sqlite.org/fts3.html#appendix_a """ |