diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 1953614401..609db40616 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -15,9 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import sys
import time
-from typing import Iterable, Tuple
+from time import monotonic as monotonic_time
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from six import iteritems, iterkeys, itervalues
from six.moves import intern, range
@@ -32,24 +32,14 @@ from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.types import Connection, Cursor
from synapse.util.stringutils import exception_to_unicode
-# import a function which will return a monotonic time, in seconds
-try:
- # on python 3, use time.monotonic, since time.clock can go backwards
- from time import monotonic as monotonic_time
-except ImportError:
- # ... but python 2 doesn't have it
- from time import clock as monotonic_time
-
logger = logging.getLogger(__name__)
-try:
- MAX_TXN_ID = sys.maxint - 1
-except AttributeError:
- # python 3 does not have a maximum int value
- MAX_TXN_ID = 2 ** 63 - 1
+# python 3 does not have a maximum int value
+MAX_TXN_ID = 2 ** 63 - 1
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
@@ -77,7 +67,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
- reactor, db_config: DatabaseConnectionConfig, engine
+ reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database.
"""
@@ -90,7 +80,9 @@ def make_pool(
)
-def make_conn(db_config: DatabaseConnectionConfig, engine):
+def make_conn(
+ db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+) -> Connection:
"""Make a new connection to the database and return it.
Returns:
@@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
return db_conn
-class LoggingTransaction(object):
+# The type of entry which goes on our after_callbacks and exception_callbacks lists.
+#
+# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
+# that mypy sees the type but the runtime python doesn't.
+_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
+
+
+class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method.
Args:
txn: The database transcation object to wrap.
- name (str): The name of this transactions for logging.
- database_engine (Sqlite3Engine|PostgresEngine)
- after_callbacks(list|None): A list that callbacks will be appended to
+ name: The name of this transactions for logging.
+ database_engine
+ after_callbacks: A list that callbacks will be appended to
that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
- exception_callbacks(list|None): A list that callbacks will be appended
+ exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks
should be allowed to be scheduled to run.
@@ -135,46 +134,67 @@ class LoggingTransaction(object):
]
def __init__(
- self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
+ self,
+ txn: Cursor,
+ name: str,
+ database_engine: BaseDatabaseEngine,
+ after_callbacks: Optional[List[_CallbackListEntry]] = None,
+ exception_callbacks: Optional[List[_CallbackListEntry]] = None,
):
- object.__setattr__(self, "txn", txn)
- object.__setattr__(self, "name", name)
- object.__setattr__(self, "database_engine", database_engine)
- object.__setattr__(self, "after_callbacks", after_callbacks)
- object.__setattr__(self, "exception_callbacks", exception_callbacks)
+ self.txn = txn
+ self.name = name
+ self.database_engine = database_engine
+ self.after_callbacks = after_callbacks
+ self.exception_callbacks = exception_callbacks
- def call_after(self, callback, *args, **kwargs):
+ def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
"""
+ # if self.after_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
- def call_on_exception(self, callback, *args, **kwargs):
+ def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+ # if self.exception_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
- def __getattr__(self, name):
- return getattr(self.txn, name)
+ def fetchall(self) -> List[Tuple]:
+ return self.txn.fetchall()
- def __setattr__(self, name, value):
- setattr(self.txn, name, value)
+ def fetchone(self) -> Tuple:
+ return self.txn.fetchone()
- def __iter__(self):
+ def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
+ @property
+ def rowcount(self) -> int:
+ return self.txn.rowcount
+
+ @property
+ def description(self) -> Any:
+ return self.txn.description
+
def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine):
- from psycopg2.extras import execute_batch
+ from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
for val in args:
self.execute(sql, val)
- def execute(self, sql, *args):
+ def execute(self, sql: str, *args: Any):
self._do_execute(self.txn.execute, sql, *args)
- def executemany(self, sql, *args):
+ def executemany(self, sql: str, *args: Any):
self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql):
@@ -207,6 +227,9 @@ class LoggingTransaction(object):
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
+ def close(self):
+ self.txn.close()
+
class PerformanceCounters(object):
def __init__(self):
@@ -251,7 +274,9 @@ class Database(object):
_TXN_ID = 0
- def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
+ def __init__(
+ self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ ):
self.hs = hs
self._clock = hs.get_clock()
self._database_config = database_config
@@ -259,9 +284,9 @@ class Database(object):
self.updates = BackgroundUpdater(hs, self)
- self._previous_txn_total_time = 0
- self._current_txn_total_time = 0
- self._previous_loop_ts = 0
+ self._previous_txn_total_time = 0.0
+ self._current_txn_total_time = 0.0
+ self._previous_loop_ts = 0.0
# TODO(paul): These can eventually be removed once the metrics code
# is running in mainline, and we have some nice monitoring frontends
@@ -463,23 +488,23 @@ class Database(object):
sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
- def runInteraction(self, desc, func, *args, **kwargs):
+ def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
"""Starts a transaction on the database and runs a given function
Arguments:
- desc (str): description of the transaction, for logging and metrics
- func (func): callback function, which will be called with a
+ desc: description of the transaction, for logging and metrics
+ func: callback function, which will be called with a
database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`.
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
+ args: positional args to pass to `func`
+ kwargs: named args to pass to `func`
Returns:
Deferred: The result of func
"""
- after_callbacks = []
- exception_callbacks = []
+ after_callbacks = [] # type: List[_CallbackListEntry]
+ exception_callbacks = [] # type: List[_CallbackListEntry]
if LoggingContext.current_context() == LoggingContext.sentinel:
logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -505,15 +530,15 @@ class Database(object):
return result
@defer.inlineCallbacks
- def runWithConnection(self, func, *args, **kwargs):
+ def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
- func (func): callback function, which will be called with a
+ func: callback function, which will be called with a
database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`.
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
+ args: positional args to pass to `func`
+ kwargs: named args to pass to `func`
Returns:
Deferred: The result of func
@@ -800,7 +825,7 @@ class Database(object):
return False
# We didn't find any existing rows, so insert a new one
- allvalues = {}
+ allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
@@ -829,7 +854,7 @@ class Database(object):
Returns:
None
"""
- allvalues = {}
+ allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(insertion_values)
@@ -916,7 +941,7 @@ class Database(object):
Returns:
None
"""
- allnames = []
+ allnames = [] # type: List[str]
allnames.extend(key_names)
allnames.extend(value_names)
@@ -1100,7 +1125,7 @@ class Database(object):
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
- results = []
+ results = [] # type: List[Dict[str, Any]]
if not iterable:
return results
@@ -1439,7 +1464,7 @@ class Database(object):
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
where_clause = "WHERE " if filters or keyvalues else ""
- arg_list = []
+ arg_list = [] # type: List[Any]
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values())
|