diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b112ff3df2..ed8a9bffb1 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -16,6 +16,7 @@
# limitations under the License.
import logging
import time
+from sys import intern
from time import monotonic as monotonic_time
from typing import (
Any,
@@ -27,15 +28,14 @@ from typing import (
Optional,
Tuple,
TypeVar,
+ cast,
+ overload,
)
-from six import iteritems, iterkeys, itervalues
-from six.moves import intern, range
-
from prometheus_client import Histogram
+from typing_extensions import Literal
from twisted.enterprise import adbapi
-from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -51,11 +51,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E
from synapse.storage.types import Connection, Cursor
from synapse.types import Collection
-logger = logging.getLogger(__name__)
-
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
+logger = logging.getLogger(__name__)
+
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")
@@ -127,7 +127,7 @@ class LoggingTransaction:
method.
Args:
- txn: The database transcation object to wrap.
+ txn: The database transaction object to wrap.
name: The name of this transactions for logging.
database_engine
after_callbacks: A list that callbacks will be appended to
@@ -162,7 +162,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
- def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+ def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@@ -173,7 +173,9 @@ class LoggingTransaction:
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
- def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+ def call_on_exception(
+ self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
+ ):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
@@ -197,7 +199,7 @@ class LoggingTransaction:
def description(self) -> Any:
return self.txn.description
- def execute_batch(self, sql, args):
+ def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
@@ -206,17 +208,17 @@ class LoggingTransaction:
for val in args:
self.execute(sql, val)
- def execute(self, sql: str, *args: Any):
+ def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args)
- def executemany(self, sql: str, *args: Any):
+ def executemany(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
- def _do_execute(self, func, sql, *args):
+ def _do_execute(self, func, sql: str, *args: Any) -> None:
sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -235,31 +237,31 @@ class LoggingTransaction:
try:
return func(sql, *args)
except Exception as e:
- logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+ sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
finally:
secs = time.time() - start
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
- def close(self):
+ def close(self) -> None:
self.txn.close()
-class PerformanceCounters(object):
+class PerformanceCounters:
def __init__(self):
self.current_counters = {}
self.previous_counters = {}
- def update(self, key, duration_secs):
+ def update(self, key: str, duration_secs: float) -> None:
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
- def interval(self, interval_duration_secs, limit=3):
+ def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
counters = []
- for name, (count, cum_time) in iteritems(self.current_counters):
+ for name, (count, cum_time) in self.current_counters.items():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append(
(
@@ -281,7 +283,10 @@ class PerformanceCounters(object):
return top_n_counters
-class Database(object):
+R = TypeVar("R")
+
+
+class DatabasePool:
"""Wraps a single physical database and connection pool.
A single database may be used by multiple data stores.
@@ -329,13 +334,12 @@ class Database(object):
self._check_safe_to_upsert,
)
- def is_running(self):
+ def is_running(self) -> bool:
"""Is the database pool currently running
"""
return self._db_pool.running
- @defer.inlineCallbacks
- def _check_safe_to_upsert(self):
+ async def _check_safe_to_upsert(self) -> None:
"""
Is it safe to use native UPSERT?
@@ -344,7 +348,7 @@ class Database(object):
If the background updates have not completed, wait 15 sec and check again.
"""
- updates = yield self.simple_select_list(
+ updates = await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
@@ -366,7 +370,7 @@ class Database(object):
self._check_safe_to_upsert,
)
- def start_profiling(self):
+ def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time()
def loop():
@@ -390,8 +394,15 @@ class Database(object):
self._clock.looping_call(loop, 10000)
def new_transaction(
- self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
- ):
+ self,
+ conn: Connection,
+ desc: str,
+ after_callbacks: List[_CallbackListEntry],
+ exception_callbacks: List[_CallbackListEntry],
+ func: "Callable[..., R]",
+ *args: Any,
+ **kwargs: Any
+ ) -> R:
start = monotonic_time()
txn_id = self._TXN_ID
@@ -421,7 +432,7 @@ class Database(object):
except self.engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
- logger.warning(
+ transaction_logger.warning(
"[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
)
if i < N:
@@ -429,18 +440,20 @@ class Database(object):
try:
conn.rollback()
except self.engine.module.Error as e1:
- logger.warning("[TXN EROLL] {%s} %s", name, e1)
+ transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
continue
raise
except self.engine.module.DatabaseError as e:
if self.engine.is_deadlock(e):
- logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+ transaction_logger.warning(
+ "[TXN DEADLOCK] {%s} %d/%d", name, i, N
+ )
if i < N:
i += 1
try:
conn.rollback()
except self.engine.module.Error as e1:
- logger.warning(
+ transaction_logger.warning(
"[TXN EROLL] {%s} %s", name, e1,
)
continue
@@ -480,7 +493,7 @@ class Database(object):
# [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
cursor.close()
except Exception as e:
- logger.debug("[TXN FAIL] {%s} %s", name, e)
+ transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = monotonic_time()
@@ -494,8 +507,9 @@ class Database(object):
self._txn_perf_counters.update(desc, duration)
sql_txn_timer.labels(desc).observe(duration)
- @defer.inlineCallbacks
- def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
+ async def runInteraction(
+ self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ ) -> R:
"""Starts a transaction on the database and runs a given function
Arguments:
@@ -508,7 +522,7 @@ class Database(object):
kwargs: named args to pass to `func`
Returns:
- Deferred: The result of func
+ The result of func
"""
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
@@ -517,7 +531,7 @@ class Database(object):
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
- result = yield self.runWithConnection(
+ result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
@@ -534,10 +548,11 @@ class Database(object):
after_callback(*after_args, **after_kwargs)
raise
- return result
+ return cast(R, result)
- @defer.inlineCallbacks
- def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+ async def runWithConnection(
+ self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ ) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
@@ -548,7 +563,7 @@ class Database(object):
kwargs: named args to pass to `func`
Returns:
- Deferred: The result of func
+ The result of func
"""
parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
if not parent_context:
@@ -571,18 +586,16 @@ class Database(object):
return func(conn, *args, **kwargs)
- result = yield make_deferred_yieldable(
+ return await make_deferred_yieldable(
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
)
- return result
-
@staticmethod
- def cursor_to_dict(cursor):
+ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
"""Converts a SQL cursor into an list of dicts.
Args:
- cursor : The DBAPI cursor which has executed a query.
+ cursor: The DBAPI cursor which has executed a query.
Returns:
A list of dicts where the key is the column header.
"""
@@ -590,10 +603,29 @@ class Database(object):
results = [dict(zip(col_headers, row)) for row in cursor]
return results
- def execute(self, desc, decoder, query, *args):
+ @overload
+ async def execute(
+ self, desc: str, decoder: Literal[None], query: str, *args: Any
+ ) -> List[Tuple[Any, ...]]:
+ ...
+
+ @overload
+ async def execute(
+ self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
+ ) -> R:
+ ...
+
+ async def execute(
+ self,
+ desc: str,
+ decoder: Optional[Callable[[Cursor], R]],
+ query: str,
+ *args: Any
+ ) -> R:
"""Runs a single query for a result set.
Args:
+ desc: description of the transaction, for logging and metrics
decoder - The function which can resolve the cursor results to
something meaningful.
query - The query string to execute
@@ -609,29 +641,33 @@ class Database(object):
else:
return txn.fetchall()
- return self.runInteraction(desc, interaction)
+ return await self.runInteraction(desc, interaction)
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
- @defer.inlineCallbacks
- def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+ async def simple_insert(
+ self,
+ table: str,
+ values: Dict[str, Any],
+ or_ignore: bool = False,
+ desc: str = "simple_insert",
+ ) -> bool:
"""Executes an INSERT query on the named table.
Args:
- table : string giving the table name
- values : dict of new column names and values for them
- or_ignore : bool stating whether an exception should be raised
+ table: string giving the table name
+ values: dict of new column names and values for them
+ or_ignore: bool stating whether an exception should be raised
when a conflicting row already exists. If True, False will be
returned by the function instead
- desc : string giving a description of the transaction
+ desc: description of the transaction, for logging and metrics
Returns:
- bool: Whether the row was inserted or not. Only useful when
- `or_ignore` is True
+ Whether the row was inserted or not. Only useful when `or_ignore` is True
"""
try:
- yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+ await self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@@ -641,7 +677,9 @@ class Database(object):
return True
@staticmethod
- def simple_insert_txn(txn, table, values):
+ def simple_insert_txn(
+ txn: LoggingTransaction, table: str, values: Dict[str, Any]
+ ) -> None:
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -652,11 +690,29 @@ class Database(object):
txn.execute(sql, vals)
- def simple_insert_many(self, table, values, desc):
- return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+ async def simple_insert_many(
+ self, table: str, values: List[Dict[str, Any]], desc: str
+ ) -> None:
+ """Executes an INSERT query on the named table.
+
+ Args:
+ table: string giving the table name
+ values: dict of new column names and values for them
+ desc: description of the transaction, for logging and metrics
+ """
+ await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
@staticmethod
- def simple_insert_many_txn(txn, table, values):
+ def simple_insert_many_txn(
+ txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
+ ) -> None:
+ """Executes an INSERT query on the named table.
+
+ Args:
+ txn: The transaction to use.
+ table: string giving the table name
+ values: dict of new column names and values for them
+ """
if not values:
return
@@ -684,16 +740,15 @@ class Database(object):
txn.executemany(sql, vals)
- @defer.inlineCallbacks
- def simple_upsert(
+ async def simple_upsert(
self,
- table,
- keyvalues,
- values,
- insertion_values={},
- desc="simple_upsert",
- lock=True,
- ):
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ desc: str = "simple_upsert",
+ lock: bool = True,
+ ) -> Optional[bool]:
"""
`lock` should generally be set to True (the default), but can be set
@@ -707,21 +762,20 @@ class Database(object):
this table.
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key columns and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key columns and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ desc: description of the transaction, for logging and metrics
+ lock: True to lock the table when doing the upsert.
Returns:
- Deferred(None or bool): Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
+ Native upserts always return None. Emulated upserts return True if a
+ new entry was created, False if an existing one was updated.
"""
attempts = 0
while True:
try:
- result = yield self.runInteraction(
+ return await self.runInteraction(
desc,
self.simple_upsert_txn,
table,
@@ -730,7 +784,6 @@ class Database(object):
insertion_values,
lock=lock,
)
- return result
except self.engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
@@ -744,29 +797,34 @@ class Database(object):
)
def simple_upsert_txn(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ lock: bool = True,
+ ) -> Optional[bool]:
"""
Pick the UPSERT method which works best on the platform. Either the
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
Args:
txn: The transaction to use.
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- None or bool: Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
+ Native upserts always return None. Emulated upserts return True if a
+ new entry was created, False if an existing one was updated.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
- return self.simple_upsert_txn_native_upsert(
+ self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
)
+ return None
else:
return self.simple_upsert_txn_emulated(
txn,
@@ -778,18 +836,23 @@ class Database(object):
)
def simple_upsert_txn_emulated(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ lock: bool = True,
+ ) -> bool:
"""
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- bool: Return True if a new entry was created, False if an existing
+ Returns True if a new entry was created, False if an existing
one was updated.
"""
# We need to lock the table :(, unless we're *really* careful
@@ -847,19 +910,21 @@ class Database(object):
return True
def simple_upsert_txn_native_upsert(
- self, txn, table, keyvalues, values, insertion_values={}
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ ) -> None:
"""
Use the native UPSERT functionality in recent PostgreSQL versions.
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- Returns:
- None
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
"""
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
@@ -989,41 +1054,93 @@ class Database(object):
return txn.execute_batch(sql, args)
- def simple_select_one(
- self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
- ):
+ @overload
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: Literal[False] = False,
+ desc: str = "simple_select_one",
+ ) -> Dict[str, Any]:
+ ...
+
+ @overload
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: Literal[True] = True,
+ desc: str = "simple_select_one",
+ ) -> Optional[Dict[str, Any]]:
+ ...
+
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: bool = False,
+ desc: str = "simple_select_one",
+ ) -> Optional[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcols : list of strings giving the names of the columns to return
-
- allow_none : If true, return None instead of failing if the SELECT
- statement returns no rows
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ retcols: list of strings giving the names of the columns to return
+ allow_none: If true, return None instead of failing if the SELECT
+ statement returns no rows
+ desc: description of the transaction, for logging and metrics
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
- def simple_select_one_onecol(
+ @overload
+ async def simple_select_one_onecol(
self,
- table,
- keyvalues,
- retcol,
- allow_none=False,
- desc="simple_select_one_onecol",
- ):
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: Literal[False] = False,
+ desc: str = "simple_select_one_onecol",
+ ) -> Any:
+ ...
+
+ @overload
+ async def simple_select_one_onecol(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: Literal[True] = True,
+ desc: str = "simple_select_one_onecol",
+ ) -> Optional[Any]:
+ ...
+
+ async def simple_select_one_onecol(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: bool = False,
+ desc: str = "simple_select_one_onecol",
+ ) -> Optional[Any]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcol : string giving the name of the column to return
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ retcol: string giving the name of the column to return
+ allow_none: If true, return None instead of failing if the SELECT
+ statement returns no rows
+ desc: description of the transaction, for logging and metrics
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc,
self.simple_select_one_onecol_txn,
table,
@@ -1032,10 +1149,39 @@ class Database(object):
allow_none=allow_none,
)
+ @overload
@classmethod
def simple_select_one_onecol_txn(
- cls, txn, table, keyvalues, retcol, allow_none=False
- ):
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: Literal[False] = False,
+ ) -> Any:
+ ...
+
+ @overload
+ @classmethod
+ def simple_select_one_onecol_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: Literal[True] = True,
+ ) -> Optional[Any]:
+ ...
+
+ @classmethod
+ def simple_select_one_onecol_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: bool = False,
+ ) -> Optional[Any]:
ret = cls.simple_select_onecol_txn(
txn, table=table, keyvalues=keyvalues, retcol=retcol
)
@@ -1049,64 +1195,85 @@ class Database(object):
raise StoreError(404, "No row found")
@staticmethod
- def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+ def simple_select_onecol_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
+ ) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
- sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
txn.execute(sql, list(keyvalues.values()))
else:
txn.execute(sql)
return [r[0] for r in txn]
- def simple_select_onecol(
- self, table, keyvalues, retcol, desc="simple_select_onecol"
- ):
+ async def simple_select_onecol(
+ self,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcol: str,
+ desc: str = "simple_select_onecol",
+ ) -> List[Any]:
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
Args:
- table (str): table name
- keyvalues (dict|None): column names and values to select the rows with
- retcol (str): column whos value we wish to retrieve.
+ table: table name
+ keyvalues: column names and values to select the rows with
+ retcol: column whos value we wish to retrieve.
+ desc: description of the transaction, for logging and metrics
Returns:
- Deferred: Results in a list
+ Results in a list
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
)
- def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+ async def simple_select_list(
+ self,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcols: Iterable[str],
+ desc: str = "simple_select_list",
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- table (str): the table name
- keyvalues (dict[str, Any] | None):
+ table: the table name
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
+ retcols: the names of the columns to return
+ desc: description of the transaction, for logging and metrics
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ A list of dictionaries.
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_select_list_txn, table, keyvalues, retcols
)
@classmethod
- def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+ def simple_select_list_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcols: Iterable[str],
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- txn : Transaction object
- table (str): the table name
- keyvalues (dict[str, T] | None):
+ txn: Transaction object
+ table: the table name
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
+ retcols: the names of the columns to return
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1121,28 +1288,29 @@ class Database(object):
return cls.cursor_to_dict(txn)
- @defer.inlineCallbacks
- def simple_select_many_batch(
+ async def simple_select_many_batch(
self,
- table,
- column,
- iterable,
- retcols,
- keyvalues={},
- desc="simple_select_many_batch",
- batch_size=100,
- ):
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ retcols: Iterable[str],
+ keyvalues: Dict[str, Any] = {},
+ desc: str = "simple_select_many_batch",
+ batch_size: int = 100,
+ ) -> List[Any]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
- Filters rows by if value of `column` is in `iterable`.
+ Filters rows by whether the value of `column` is in `iterable`.
Args:
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+ iterable: list
+ retcols: list of strings giving the names of the columns to return
+ keyvalues: dict of column names and values to select the rows with
+ desc: description of the transaction, for logging and metrics
+ batch_size: the number of rows for each select query
"""
results = [] # type: List[Dict[str, Any]]
@@ -1156,7 +1324,7 @@ class Database(object):
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
- rows = yield self.runInteraction(
+ rows = await self.runInteraction(
desc,
self.simple_select_many_txn,
table,
@@ -1171,19 +1339,27 @@ class Database(object):
return results
@classmethod
- def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+ def simple_select_many_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
- Filters rows by if value of `column` is in `iterable`.
+ Filters rows by whether the value of `column` is in `iterable`.
Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ txn: Transaction object
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+ iterable: list
+ keyvalues: dict of column names and values to select the rows with
+ retcols: list of strings giving the names of the columns to return
"""
if not iterable:
return []
@@ -1191,7 +1367,7 @@ class Database(object):
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
- for key, value in iteritems(keyvalues):
+ for key, value in keyvalues.items():
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -1204,15 +1380,26 @@ class Database(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
- def simple_update(self, table, keyvalues, updatevalues, desc):
- return self.runInteraction(
+ async def simple_update(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ desc: str,
+ ) -> int:
+ return await self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues
)
@staticmethod
- def simple_update_txn(txn, table, keyvalues, updatevalues):
+ def simple_update_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ ) -> int:
if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else:
where = ""
@@ -1226,32 +1413,34 @@ class Database(object):
return txn.rowcount
- def simple_update_one(
- self, table, keyvalues, updatevalues, desc="simple_update_one"
- ):
+ async def simple_update_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ desc: str = "simple_update_one",
+ ) -> None:
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- updatevalues : dict giving column names and values to update
- retcols : optional list of column names to return
-
- If present, retcols gives a list of column names on which to perform
- a SELECT statement *before* performing the UPDATE statement. The values
- of these will be returned in a dict.
-
- These are performed within the same transaction, allowing an atomic
- get-and-set. This can be used to implement compare-and-set by putting
- the update column in the 'keyvalues' dict as well.
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ updatevalues: dict giving column names and values to update
+ desc: description of the transaction, for logging and metrics
"""
- return self.runInteraction(
+ await self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
)
@classmethod
- def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ def simple_update_one_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ ) -> None:
rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
if rowcount == 0:
@@ -1259,8 +1448,18 @@ class Database(object):
if rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
+ # Ideally we could use the overload decorator here to specify that the
+ # return type is only optional if allow_none is True, but this does not work
+ # when you call a static method from an instance.
+ # See https://github.com/python/mypy/issues/7781
@staticmethod
- def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+ def simple_select_one_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: bool = False,
+ ) -> Optional[Dict[str, Any]]:
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
@@ -1279,24 +1478,29 @@ class Database(object):
return dict(zip(retcols, row))
- def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+ async def simple_delete_one(
+ self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
+ ) -> None:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ desc: description of the transaction, for logging and metrics
"""
- return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+ await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
@staticmethod
- def simple_delete_one_txn(txn, table, keyvalues):
+ def simple_delete_one_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+ ) -> None:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
@@ -1309,11 +1513,38 @@ class Database(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
- def simple_delete(self, table, keyvalues, desc):
- return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+ async def simple_delete(
+ self, table: str, keyvalues: Dict[str, Any], desc: str
+ ) -> int:
+ """Executes a DELETE query on the named table.
+
+ Filters rows by the key-value pairs.
+
+ Args:
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ desc: description of the transaction, for logging and metrics
+
+ Returns:
+ The number of deleted rows.
+ """
+ return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
@staticmethod
- def simple_delete_txn(txn, table, keyvalues):
+ def simple_delete_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+ ) -> int:
+ """Executes a DELETE query on the named table.
+
+ Filters rows by the key-value pairs.
+
+ Args:
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+
+ Returns:
+ The number of deleted rows.
+ """
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1322,26 +1553,53 @@ class Database(object):
txn.execute(sql, list(keyvalues.values()))
return txn.rowcount
- def simple_delete_many(self, table, column, iterable, keyvalues, desc):
- return self.runInteraction(
+ async def simple_delete_many(
+ self,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ desc: str,
+ ) -> int:
+ """Executes a DELETE query on the named table.
+
+ Filters rows by if value of `column` is in `iterable`.
+
+ Args:
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+ iterable: list
+ keyvalues: dict of column names and values to select the rows with
+ desc: description of the transaction, for logging and metrics
+
+ Returns:
+ Number rows deleted
+ """
+ return await self.runInteraction(
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
)
@staticmethod
- def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ def simple_delete_many_txn(
+ txn: LoggingTransaction,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ ) -> int:
"""Executes a DELETE query on the named table.
Filters rows by if value of `column` is in `iterable`.
Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
+ txn: Transaction object
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+ iterable: list
+ keyvalues: dict of column names and values to select the rows with
Returns:
- int: Number rows deleted
+ Number rows deleted
"""
if not iterable:
return 0
@@ -1351,7 +1609,7 @@ class Database(object):
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
- for key, value in iteritems(keyvalues):
+ for key, value in keyvalues.items():
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -1362,8 +1620,14 @@ class Database(object):
return txn.rowcount
def get_cache_dict(
- self, db_conn, table, entity_column, stream_column, max_value, limit=100000
- ):
+ self,
+ db_conn: Connection,
+ table: str,
+ entity_column: str,
+ stream_column: str,
+ max_value: int,
+ limit: int = 100000,
+ ) -> Tuple[Dict[Any, int], int]:
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
@@ -1388,71 +1652,25 @@ class Database(object):
txn.close()
if cache:
- min_val = min(itervalues(cache))
+ min_val = min(cache.values())
else:
min_val = max_value
return cache, min_val
- def simple_select_list_paginate(
- self,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=None,
- keyvalues=None,
- order_direction="ASC",
- desc="simple_select_list_paginate",
- ):
- """
- Executes a SELECT query on the named table with start and limit,
- of row numbers, which may return zero or number of rows from start to limit,
- returning the result as a list of dicts.
-
- Args:
- table (str): the table name
- filters (dict[str, T] | None):
- column names and values to filter the rows with, or None to not
- apply a WHERE ? LIKE ? clause.
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self.runInteraction(
- desc,
- self.simple_select_list_paginate_txn,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=filters,
- keyvalues=keyvalues,
- order_direction=order_direction,
- )
-
@classmethod
def simple_select_list_paginate_txn(
cls,
- txn,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=None,
- keyvalues=None,
- order_direction="ASC",
- ):
+ txn: LoggingTransaction,
+ table: str,
+ orderby: str,
+ start: int,
+ limit: int,
+ retcols: Iterable[str],
+ filters: Optional[Dict[str, Any]] = None,
+ keyvalues: Optional[Dict[str, Any]] = None,
+ order_direction: str = "ASC",
+ ) -> List[Dict[str, Any]]:
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
@@ -1463,21 +1681,22 @@ class Database(object):
using 'AND'.
Args:
- txn : Transaction object
- table (str): the table name
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- filters (dict[str, T] | None):
+ txn: Transaction object
+ table: the table name
+ orderby: Column to order the results by.
+ start: Index to begin the query at.
+ limit: Number of results to return.
+ retcols: the names of the columns to return
+ filters:
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
- keyvalues (dict[str, T] | None):
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+ order_direction: Whether the results should be ordered "ASC" or "DESC".
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ The result as a list of dictionaries.
"""
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -1503,51 +1722,65 @@ class Database(object):
return cls.cursor_to_dict(txn)
- def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+ async def simple_search_list(
+ self,
+ table: str,
+ term: Optional[str],
+ col: str,
+ retcols: Iterable[str],
+ desc="simple_search_list",
+ ) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
+ table: the table name
+ term: term for searching the table matched to a column.
+ col: column to query term should be matched to
+ retcols: the names of the columns to return
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
+ A list of dictionaries or None.
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_search_list_txn, table, term, col, retcols
)
@classmethod
- def simple_search_list_txn(cls, txn, table, term, col, retcols):
+ def simple_search_list_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ term: Optional[str],
+ col: str,
+ retcols: Iterable[str],
+ ) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- txn : Transaction object
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
+ txn: Transaction object
+ table: the table name
+ term: term for searching the table matched to a column.
+ col: column to query term should be matched to
+ retcols: the names of the columns to return
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
+ None if no term is given, otherwise a list of dictionaries.
"""
if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
termvalues = ["%%" + term + "%%"]
txn.execute(sql, termvalues)
else:
- return 0
+ return None
return cls.cursor_to_dict(txn)
def make_in_list_sql_clause(
- database_engine, column: str, iterable: Iterable
+ database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.
|