diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6814bf5fcf..ab49d227de 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,12 +19,11 @@ import random
from abc import ABCMeta
from typing import Any, Optional
-from canonicaljson import json
-
from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id
+from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -99,13 +98,13 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
- # Decode it to a Unicode string before feeding it to json.loads, since
+ # Decode it to a Unicode string before feeding it to the JSON decoder, since
# Python 3.5 does not support deserializing bytes.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode("utf8")
try:
- return json.loads(db_content)
+ return json_decoder.decode(db_content)
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise
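# A minimal sketch of the `json_encoder` / `json_decoder` helpers imported from
# `synapse.util` above (assumed shape only; the real module may configure these
# differently, e.g. with canonical or NaN-rejecting options):
import json

# A compact, reusable encoder so call sites don't repeat json.dumps() options.
json_encoder = json.JSONEncoder(separators=(",", ":"))

# A plain decoder, used as json_decoder.decode(db_content) in db_to_json().
json_decoder = json.JSONDecoder()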
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 90a1f9e8b1..0db900fa0e 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -16,9 +16,8 @@
import logging
from typing import Optional
-from canonicaljson import json
-
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import json_encoder
from . import engines
@@ -415,13 +414,14 @@ class BackgroundUpdater(object):
self.register_background_update_handler(update_name, updater)
- def _end_background_update(self, update_name):
+ async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.
Args:
- update_name(str): The name of the completed task to remove
+            update_name: The name of the completed task to remove
+
Returns:
- A deferred that completes once the task is removed.
+ None, completes once the task is removed.
"""
if update_name != self._current_background_update:
raise Exception(
@@ -429,7 +429,7 @@ class BackgroundUpdater(object):
% update_name
)
self._current_background_update = None
- return self.db_pool.simple_delete_one(
+ await self.db_pool.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
@@ -457,7 +457,7 @@ class BackgroundUpdater(object):
progress(dict): The progress of the update.
"""
- progress_json = json.dumps(progress)
+ progress_json = json_encoder.encode(progress)
self.db_pool.simple_update_one_txn(
txn,
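# A minimal sketch of the Deferred -> async/await conversion applied to
# _end_background_update above (illustrative class and method names; assumes
# Twisted's awaitable Deferreds, which this patch relies on elsewhere).
from twisted.internet import defer


class BeforeUpdater:
    def _end_background_update(self, update_name):
        # Previously returned a Deferred that callers had to yield on.
        return defer.succeed(None)


class AfterUpdater:
    async def _end_background_update(self, update_name: str) -> None:
        # Callers now await the coroutine directly, or wrap it with
        # defer.ensureDeferred() when invoked from inlineCallbacks code.
        await defer.succeed(None)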
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 8a9e06efcf..2f6f49a4bf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,9 +28,12 @@ from typing import (
Optional,
Tuple,
TypeVar,
+ Union,
+ overload,
)
from prometheus_client import Histogram
+from typing_extensions import Literal
from twisted.enterprise import adbapi
from twisted.internet import defer
@@ -125,7 +128,7 @@ class LoggingTransaction:
method.
Args:
- txn: The database transcation object to wrap.
+ txn: The database transaction object to wrap.
name: The name of this transactions for logging.
database_engine
after_callbacks: A list that callbacks will be appended to
@@ -160,7 +163,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
- def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+ def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@@ -171,7 +174,9 @@ class LoggingTransaction:
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
- def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+ def call_on_exception(
+ self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
+ ):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
@@ -195,7 +200,7 @@ class LoggingTransaction:
def description(self) -> Any:
return self.txn.description
- def execute_batch(self, sql, args):
+ def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
@@ -204,17 +209,17 @@ class LoggingTransaction:
for val in args:
self.execute(sql, val)
- def execute(self, sql: str, *args: Any):
+ def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args)
- def executemany(self, sql: str, *args: Any):
+ def executemany(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
- def _do_execute(self, func, sql, *args):
+ def _do_execute(self, func, sql: str, *args: Any) -> None:
sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -240,7 +245,7 @@ class LoggingTransaction:
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
- def close(self):
+ def close(self) -> None:
self.txn.close()
@@ -249,13 +254,13 @@ class PerformanceCounters(object):
self.current_counters = {}
self.previous_counters = {}
- def update(self, key, duration_secs):
+ def update(self, key: str, duration_secs: float) -> None:
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
- def interval(self, interval_duration_secs, limit=3):
+ def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
counters = []
for name, (count, cum_time) in self.current_counters.items():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
@@ -279,6 +284,9 @@ class PerformanceCounters(object):
return top_n_counters
+R = TypeVar("R")
+
+
class DatabasePool(object):
"""Wraps a single physical database and connection pool.
@@ -327,12 +335,12 @@ class DatabasePool(object):
self._check_safe_to_upsert,
)
- def is_running(self):
+ def is_running(self) -> bool:
"""Is the database pool currently running
"""
return self._db_pool.running
- async def _check_safe_to_upsert(self):
+ async def _check_safe_to_upsert(self) -> None:
"""
Is it safe to use native UPSERT?
@@ -363,7 +371,7 @@ class DatabasePool(object):
self._check_safe_to_upsert,
)
- def start_profiling(self):
+ def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time()
def loop():
@@ -387,8 +395,15 @@ class DatabasePool(object):
self._clock.looping_call(loop, 10000)
def new_transaction(
- self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
- ):
+ self,
+ conn: Connection,
+ desc: str,
+ after_callbacks: List[_CallbackListEntry],
+ exception_callbacks: List[_CallbackListEntry],
+ func: "Callable[..., R]",
+ *args: Any,
+ **kwargs: Any
+ ) -> R:
start = monotonic_time()
txn_id = self._TXN_ID
@@ -516,14 +531,16 @@ class DatabasePool(object):
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
- result = yield self.runWithConnection(
- self.new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- **kwargs
+ result = yield defer.ensureDeferred(
+ self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ **kwargs
+ )
)
for after_callback, after_args, after_kwargs in after_callbacks:
@@ -535,8 +552,9 @@ class DatabasePool(object):
return result
- @defer.inlineCallbacks
- def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+ async def runWithConnection(
+ self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ ) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
@@ -547,7 +565,7 @@ class DatabasePool(object):
kwargs: named args to pass to `func`
Returns:
- Deferred: The result of func
+ The result of func
"""
parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
if not parent_context:
@@ -570,18 +588,16 @@ class DatabasePool(object):
return func(conn, *args, **kwargs)
- result = yield make_deferred_yieldable(
+ return await make_deferred_yieldable(
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
)
- return result
-
@staticmethod
- def cursor_to_dict(cursor):
+ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
"""Converts a SQL cursor into an list of dicts.
Args:
- cursor : The DBAPI cursor which has executed a query.
+ cursor: The DBAPI cursor which has executed a query.
Returns:
A list of dicts where the key is the column header.
"""
@@ -589,7 +605,13 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor]
return results
- def execute(self, desc, decoder, query, *args):
+ async def execute(
+ self,
+ desc: str,
+ decoder: Optional[Callable[[Cursor], R]],
+ query: str,
+ *args: Any
+ ) -> R:
"""Runs a single query for a result set.
Args:
@@ -608,25 +630,30 @@ class DatabasePool(object):
else:
return txn.fetchall()
- return self.runInteraction(desc, interaction)
+ return await self.runInteraction(desc, interaction)
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
- async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+ async def simple_insert(
+ self,
+ table: str,
+ values: Dict[str, Any],
+ or_ignore: bool = False,
+ desc: str = "simple_insert",
+ ) -> bool:
"""Executes an INSERT query on the named table.
Args:
- table : string giving the table name
- values : dict of new column names and values for them
- or_ignore : bool stating whether an exception should be raised
+ table: string giving the table name
+ values: dict of new column names and values for them
+ or_ignore: bool stating whether an exception should be raised
when a conflicting row already exists. If True, False will be
returned by the function instead
- desc : string giving a description of the transaction
+ desc: string giving a description of the transaction
Returns:
- bool: Whether the row was inserted or not. Only useful when
- `or_ignore` is True
+ Whether the row was inserted or not. Only useful when `or_ignore` is True
"""
try:
await self.runInteraction(desc, self.simple_insert_txn, table, values)
@@ -639,7 +666,9 @@ class DatabasePool(object):
return True
@staticmethod
- def simple_insert_txn(txn, table, values):
+ def simple_insert_txn(
+ txn: LoggingTransaction, table: str, values: Dict[str, Any]
+ ) -> None:
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -650,11 +679,30 @@ class DatabasePool(object):
txn.execute(sql, vals)
- def simple_insert_many(self, table, values, desc):
- return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+ async def simple_insert_many(
+ self, table: str, values: List[Dict[str, Any]], desc: str
+ ) -> None:
+ """Executes an INSERT query on the named table.
+
+ Args:
+ table: string giving the table name
+            values: a list of dicts, each giving the column names and values to insert
+ desc: string giving a description of the transaction
+ """
+ await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
@staticmethod
- def simple_insert_many_txn(txn, table, values):
+ def simple_insert_many_txn(
+ txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
+ ) -> None:
+ """Executes an INSERT query on the named table.
+
+ Args:
+ txn: The transaction to use.
+ table: string giving the table name
+            values: a list of dicts, each giving the column names and values to insert
+ """
if not values:
return
@@ -684,13 +732,13 @@ class DatabasePool(object):
async def simple_upsert(
self,
- table,
- keyvalues,
- values,
- insertion_values={},
- desc="simple_upsert",
- lock=True,
- ):
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ desc: str = "simple_upsert",
+ lock: bool = True,
+ ) -> Optional[bool]:
"""
`lock` should generally be set to True (the default), but can be set
@@ -704,16 +752,14 @@ class DatabasePool(object):
this table.
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key columns and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key columns and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- None or bool: Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
+ Native upserts always return None. Emulated upserts return True if a
+ new entry was created, False if an existing one was updated.
"""
attempts = 0
while True:
@@ -740,29 +786,34 @@ class DatabasePool(object):
)
def simple_upsert_txn(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ lock: bool = True,
+ ) -> Optional[bool]:
"""
Pick the UPSERT method which works best on the platform. Either the
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
Args:
txn: The transaction to use.
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- None or bool: Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
+ Native upserts always return None. Emulated upserts return True if a
+ new entry was created, False if an existing one was updated.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
- return self.simple_upsert_txn_native_upsert(
+ self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
)
+ return None
else:
return self.simple_upsert_txn_emulated(
txn,
@@ -774,18 +825,23 @@ class DatabasePool(object):
)
def simple_upsert_txn_emulated(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ lock: bool = True,
+ ) -> bool:
"""
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- bool: Return True if a new entry was created, False if an existing
+ Returns True if a new entry was created, False if an existing
one was updated.
"""
# We need to lock the table :(, unless we're *really* careful
@@ -843,19 +899,21 @@ class DatabasePool(object):
return True
def simple_upsert_txn_native_upsert(
- self, txn, table, keyvalues, values, insertion_values={}
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ ) -> None:
"""
Use the native UPSERT functionality in recent PostgreSQL versions.
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- Returns:
- None
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
"""
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
@@ -985,41 +1043,70 @@ class DatabasePool(object):
return txn.execute_batch(sql, args)
- def simple_select_one(
- self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
- ):
+ @overload
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: Literal[False] = False,
+ desc: str = "simple_select_one",
+ ) -> Dict[str, Any]:
+ ...
+
+ @overload
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: Literal[True] = True,
+ desc: str = "simple_select_one",
+ ) -> Optional[Dict[str, Any]]:
+ ...
+
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: bool = False,
+ desc: str = "simple_select_one",
+ ) -> Optional[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcols : list of strings giving the names of the columns to return
-
- allow_none : If true, return None instead of failing if the SELECT
- statement returns no rows
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ retcols: list of strings giving the names of the columns to return
+ allow_none: If true, return None instead of failing if the SELECT
+ statement returns no rows
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
- def simple_select_one_onecol(
+ async def simple_select_one_onecol(
self,
- table,
- keyvalues,
- retcol,
- allow_none=False,
- desc="simple_select_one_onecol",
- ):
+ table: str,
+ keyvalues: Dict[str, Any],
+        retcol: str,
+ allow_none: bool = False,
+ desc: str = "simple_select_one_onecol",
+ ) -> Optional[Any]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcol : string giving the name of the column to return
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ retcol: string giving the name of the column to return
+ allow_none: If true, return None instead of failing if the SELECT
+ statement returns no rows
+ desc: description of the transaction, for logging and metrics
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc,
self.simple_select_one_onecol_txn,
table,
@@ -1030,8 +1117,13 @@ class DatabasePool(object):
@classmethod
def simple_select_one_onecol_txn(
- cls, txn, table, keyvalues, retcol, allow_none=False
- ):
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+        retcol: str,
+ allow_none: bool = False,
+ ) -> Optional[Any]:
ret = cls.simple_select_onecol_txn(
txn, table=table, keyvalues=keyvalues, retcol=retcol
)
@@ -1045,7 +1137,12 @@ class DatabasePool(object):
raise StoreError(404, "No row found")
@staticmethod
- def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+ def simple_select_onecol_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+        retcol: str,
+ ) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
@@ -1056,53 +1153,70 @@ class DatabasePool(object):
return [r[0] for r in txn]
- def simple_select_onecol(
- self, table, keyvalues, retcol, desc="simple_select_onecol"
- ):
+ async def simple_select_onecol(
+ self,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcol: str,
+ desc: str = "simple_select_onecol",
+ ) -> List[Any]:
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
Args:
- table (str): table name
- keyvalues (dict|None): column names and values to select the rows with
- retcol (str): column whos value we wish to retrieve.
+ table: table name
+ keyvalues: column names and values to select the rows with
+            retcol: column whose value we wish to retrieve.
Returns:
- Deferred: Results in a list
+            A list of the values of the named column from the selected rows.
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
)
- def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+ async def simple_select_list(
+ self,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcols: Iterable[str],
+ desc: str = "simple_select_list",
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- table (str): the table name
- keyvalues (dict[str, Any] | None):
+ table: the table name
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
+ retcols: the names of the columns to return
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ A list of dictionaries.
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_select_list_txn, table, keyvalues, retcols
)
@classmethod
- def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+ def simple_select_list_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcols: Iterable[str],
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- txn : Transaction object
- table (str): the table name
- keyvalues (dict[str, T] | None):
+ txn: Transaction object
+ table: the table name
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
+ retcols: the names of the columns to return
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1119,25 +1233,25 @@ class DatabasePool(object):
async def simple_select_many_batch(
self,
- table,
- column,
- iterable,
- retcols,
- keyvalues={},
- desc="simple_select_many_batch",
- batch_size=100,
- ):
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ retcols: Iterable[str],
+ keyvalues: Dict[str, Any] = {},
+ desc: str = "simple_select_many_batch",
+ batch_size: int = 100,
+ ) -> List[Any]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Filters rows by if value of `column` is in `iterable`.
Args:
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+            iterable: list of values to match against `column`
+ keyvalues: dict of column names and values to select the rows with
+ retcols: list of strings giving the names of the columns to return
"""
results = [] # type: List[Dict[str, Any]]
@@ -1166,19 +1280,27 @@ class DatabasePool(object):
return results
@classmethod
- def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+ def simple_select_many_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Filters rows by if value of `column` is in `iterable`.
Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ txn: Transaction object
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+            iterable: list of values to match against `column`
+ keyvalues: dict of column names and values to select the rows with
+ retcols: list of strings giving the names of the columns to return
"""
if not iterable:
return []
@@ -1199,13 +1321,24 @@ class DatabasePool(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
- def simple_update(self, table, keyvalues, updatevalues, desc):
- return self.runInteraction(
+ async def simple_update(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ desc: str,
+ ) -> int:
+ return await self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues
)
@staticmethod
- def simple_update_txn(txn, table, keyvalues, updatevalues):
+ def simple_update_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ ) -> int:
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else:
@@ -1221,32 +1354,33 @@ class DatabasePool(object):
return txn.rowcount
- def simple_update_one(
- self, table, keyvalues, updatevalues, desc="simple_update_one"
- ):
+ async def simple_update_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ desc: str = "simple_update_one",
+ ) -> None:
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- updatevalues : dict giving column names and values to update
- retcols : optional list of column names to return
-
- If present, retcols gives a list of column names on which to perform
- a SELECT statement *before* performing the UPDATE statement. The values
- of these will be returned in a dict.
-
- These are performed within the same transaction, allowing an atomic
- get-and-set. This can be used to implement compare-and-set by putting
- the update column in the 'keyvalues' dict as well.
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ updatevalues: dict giving column names and values to update
"""
- return self.runInteraction(
+ await self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
)
@classmethod
- def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ def simple_update_one_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ ) -> None:
rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
if rowcount == 0:
@@ -1254,8 +1388,18 @@ class DatabasePool(object):
if rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
+ # Ideally we could use the overload decorator here to specify that the
+ # return type is only optional if allow_none is True, but this does not work
+ # when you call a static method from an instance.
+ # See https://github.com/python/mypy/issues/7781
@staticmethod
- def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+ def simple_select_one_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: bool = False,
+ ) -> Optional[Dict[str, Any]]:
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
@@ -1274,24 +1418,28 @@ class DatabasePool(object):
return dict(zip(retcols, row))
- def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+ async def simple_delete_one(
+ self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
+ ) -> None:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
"""
- return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+ await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
@staticmethod
- def simple_delete_one_txn(txn, table, keyvalues):
+ def simple_delete_one_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+ ) -> None:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
@@ -1304,11 +1452,13 @@ class DatabasePool(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
- def simple_delete(self, table, keyvalues, desc):
+ def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str):
return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
@staticmethod
- def simple_delete_txn(txn, table, keyvalues):
+ def simple_delete_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+ ) -> int:
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1317,26 +1467,39 @@ class DatabasePool(object):
txn.execute(sql, list(keyvalues.values()))
return txn.rowcount
- def simple_delete_many(self, table, column, iterable, keyvalues, desc):
- return self.runInteraction(
+ async def simple_delete_many(
+ self,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ desc: str,
+ ) -> int:
+ return await self.runInteraction(
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
)
@staticmethod
- def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ def simple_delete_many_txn(
+ txn: LoggingTransaction,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ ) -> int:
"""Executes a DELETE query on the named table.
Filters rows by if value of `column` is in `iterable`.
Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
+ txn: Transaction object
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+            iterable: list of values to match against `column`
+ keyvalues: dict of column names and values to select the rows with
Returns:
- int: Number rows deleted
+            Number of rows deleted
"""
if not iterable:
return 0
@@ -1357,8 +1520,14 @@ class DatabasePool(object):
return txn.rowcount
def get_cache_dict(
- self, db_conn, table, entity_column, stream_column, max_value, limit=100000
- ):
+ self,
+ db_conn: Connection,
+ table: str,
+ entity_column: str,
+ stream_column: str,
+ max_value: int,
+ limit: int = 100000,
+ ) -> Tuple[Dict[Any, int], int]:
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
@@ -1389,65 +1558,19 @@ class DatabasePool(object):
return cache, min_val
- def simple_select_list_paginate(
- self,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=None,
- keyvalues=None,
- order_direction="ASC",
- desc="simple_select_list_paginate",
- ):
- """
- Executes a SELECT query on the named table with start and limit,
- of row numbers, which may return zero or number of rows from start to limit,
- returning the result as a list of dicts.
-
- Args:
- table (str): the table name
- filters (dict[str, T] | None):
- column names and values to filter the rows with, or None to not
- apply a WHERE ? LIKE ? clause.
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self.runInteraction(
- desc,
- self.simple_select_list_paginate_txn,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=filters,
- keyvalues=keyvalues,
- order_direction=order_direction,
- )
-
@classmethod
def simple_select_list_paginate_txn(
cls,
- txn,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=None,
- keyvalues=None,
- order_direction="ASC",
- ):
+ txn: LoggingTransaction,
+ table: str,
+ orderby: str,
+ start: int,
+ limit: int,
+ retcols: Iterable[str],
+ filters: Optional[Dict[str, Any]] = None,
+ keyvalues: Optional[Dict[str, Any]] = None,
+ order_direction: str = "ASC",
+ ) -> List[Dict[str, Any]]:
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
@@ -1458,21 +1581,22 @@ class DatabasePool(object):
using 'AND'.
Args:
- txn : Transaction object
- table (str): the table name
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- filters (dict[str, T] | None):
+ txn: Transaction object
+ table: the table name
+ orderby: Column to order the results by.
+ start: Index to begin the query at.
+ limit: Number of results to return.
+ retcols: the names of the columns to return
+ filters:
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
- keyvalues (dict[str, T] | None):
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+ order_direction: Whether the results should be ordered "ASC" or "DESC".
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ The result as a list of dictionaries.
"""
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -1498,38 +1622,52 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn)
- def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+ async def simple_search_list(
+ self,
+ table: str,
+ term: Optional[str],
+ col: str,
+ retcols: Iterable[str],
+        desc: str = "simple_search_list",
+ ) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
+ table: the table name
+ term: term for searching the table matched to a column.
+ col: column to query term should be matched to
+ retcols: the names of the columns to return
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
+ A list of dictionaries or None.
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_search_list_txn, table, term, col, retcols
)
@classmethod
- def simple_search_list_txn(cls, txn, table, term, col, retcols):
+ def simple_search_list_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ term: Optional[str],
+ col: str,
+ retcols: Iterable[str],
+ ) -> Union[List[Dict[str, Any]], int]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- txn : Transaction object
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
+ txn: Transaction object
+ table: the table name
+ term: term for searching the table matched to a column.
+ col: column to query term should be matched to
+ retcols: the names of the columns to return
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
+ 0 if no term is given, otherwise a list of dictionaries.
"""
if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
@@ -1542,7 +1680,7 @@ class DatabasePool(object):
def make_in_list_sql_clause(
- database_engine, column: str, iterable: Iterable
+ database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 4406e58273..0ac854aee2 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -87,12 +87,21 @@ class Databases(object):
logger.info("Database %r prepared", db_name)
+ # Closing the context manager doesn't close the connection.
+ # psycopg will close the connection when the object gets GCed, but *only*
+ # if the PID is the same as when the connection was opened [1], and
+ # it may not be if we fork in the meantime.
+ #
+ # [1]: https://github.com/psycopg/psycopg2/blob/2_8_5/psycopg/connection_type.c#L1378
+
+ db_conn.close()
+
# Sanity check that we have actually configured all the required stores.
if not main:
raise Exception("No 'main' data store configured")
if not state:
- raise Exception("No 'main' data store configured")
+ raise Exception("No 'state' data store configured")
# We use local variables here to ensure that the databases do not have
# optional types.
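# A small illustration of the fork hazard described in the comment above
# (connection parameters are placeholders): psycopg2 only closes a garbage-
# collected connection when the current PID matches the one that opened it,
# hence the explicit db_conn.close() added by this patch.
import psycopg2

conn = psycopg2.connect(dbname="synapse", user="synapse")
try:
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
        cur.fetchone()
finally:
    # Close explicitly; __del__ may silently skip closing in a forked child.
    conn.close()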
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 17fa470919..70cf15dd7f 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,6 +18,7 @@
import calendar
import logging
import time
+from typing import Any, Dict, List, Optional
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
@@ -476,14 +477,13 @@ class DataStore(
"generate_user_daily_visits", _generate_user_daily_visits
)
- def get_users(self):
+ async def get_users(self) -> List[Dict[str, Any]]:
"""Function to retrieve a list of users in users table.
- Args:
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ A list of dictionaries representing users.
"""
- return self.db_pool.simple_select_list(
+ return await self.db_pool.simple_select_list(
table="users",
keyvalues={},
retcols=[
@@ -498,7 +498,7 @@ class DataStore(
)
def get_users_paginate(
- self, start, limit, name=None, guests=True, deactivated=False
+ self, start, limit, user_id=None, name=None, guests=True, deactivated=False
):
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
@@ -507,7 +507,8 @@ class DataStore(
Args:
start (int): start number to begin the query from
limit (int): number of rows to retrieve
- name (string): filter for user names
+            user_id (string): search for user_id. Ignored if name is not None
+ name (string): search for local part of user_id or display name
guests (bool): whether to in include guest users
deactivated (bool): whether to include deactivated users
Returns:
@@ -516,11 +517,14 @@ class DataStore(
def get_users_paginate_txn(txn):
filters = []
- args = []
+ args = [self.hs.config.server_name]
if name:
+ filters.append("(name LIKE ? OR displayname LIKE ?)")
+ args.extend(["@%" + name + "%:%", "%" + name + "%"])
+ elif user_id:
filters.append("name LIKE ?")
- args.append("%" + name + "%")
+ args.extend(["%" + user_id + "%"])
if not guests:
filters.append("is_guest = 0")
@@ -530,20 +534,23 @@ class DataStore(
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
- sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
- txn.execute(sql, args)
- count = txn.fetchone()[0]
-
- args = [self.hs.config.server_name] + args + [limit, start]
- sql = """
- SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+ sql_base = """
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
{}
- ORDER BY u.name LIMIT ? OFFSET ?
""".format(
where_clause
)
+ sql = "SELECT COUNT(*) as total_users " + sql_base
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = (
+ "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+ + sql_base
+ + " ORDER BY u.name LIMIT ? OFFSET ?"
+ )
+ args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
return users, count
@@ -552,17 +559,17 @@ class DataStore(
"get_users_paginate_txn", get_users_paginate_txn
)
- def search_users(self, term):
+ async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
"""Function to search users list for one or more users with
the matched term.
Args:
- term (str): search term
- col (str): column to query term should be matched to
+ term: search term
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ A list of dictionaries or None.
"""
- return self.db_pool.simple_search_list(
+ return await self.db_pool.simple_search_list(
table="users",
term=term,
col="name",
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 82aac2bbf3..04042a2c98 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
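# A minimal sketch of why the stream id generators are now entered with
# `with await ..._id_gen.get_next() as next_id:` in hunks like the ones above
# (hypothetical generator; in Synapse the async step may need to reserve the id
# in the database for multi-writer deployments).
import contextlib
import itertools


class SketchIdGenerator:
    def __init__(self) -> None:
        self._counter = itertools.count(1)

    async def get_next(self):
        # The allocation itself is asynchronous...
        next_id = next(self._counter)

        @contextlib.contextmanager
        def _ctx():
            # ...but callers still receive a context manager, giving the
            # `with await gen.get_next() as next_id:` shape seen in this patch.
            yield next_id
            # A real implementation would mark next_id as persisted here.

        return _ctx()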
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 02568a2391..77723f7d4d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -16,13 +16,12 @@
import logging
import re
-from canonicaljson import json
-
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util import json_encoder
logger = logging.getLogger(__name__)
@@ -204,7 +203,7 @@ class ApplicationServiceTransactionWorkerStore(
new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table
- event_ids = json.dumps([e.event_id for e in events])
+ event_ids = json_encoder.encode([e.event_id for e in events])
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 10de446065..1e7637a6f5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
},
)
- def get_cache_stream_token(self, instance_name):
+ def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen:
- return self._cache_id_gen.get_current_token(instance_name)
+ return self._cache_id_gen.get_current_token_for_writer(instance_name)
else:
return 0
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1f6e995c4f..bb85637a95 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9a786e2929..a811a39eb5 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
- def get_device(self, user_id: str, device_id: str):
+ async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
- defer.Deferred for a dict containing the device information
+ A dict containing the device information
Raises:
StoreError: if the device is not found
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
            The new stream ID.
"""
- with self._device_list_id_gen.get_next() as stream_id:
+ with await self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
- def get_device_list_last_stream_id_for_remote(self, user_id: str):
+ async def get_device_list_last_stream_id_for_remote(
+ self, user_id: str
+ ) -> Optional[Any]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
@@ -1146,7 +1148,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+ with await self._device_list_id_gen.get_next_mult(
+ len(device_ids)
+ ) as stream_ids:
await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
@@ -1159,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with self._device_list_id_gen.get_next_mult(
+ with await self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 037e02603c..405b5eafa5 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -14,7 +14,7 @@
# limitations under the License.
from collections import namedtuple
-from typing import Iterable, Optional
+from typing import Iterable, List, Optional
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
- def get_room_alias_creator(self, room_alias):
- return self.db_pool.simple_select_one_onecol(
+ async def get_room_alias_creator(self, room_alias: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
@@ -68,8 +68,8 @@ class DirectoryWorkerStore(SQLBaseStore):
)
@cached(max_entries=5000)
- def get_aliases_for_room(self, room_id):
- return self.db_pool.simple_select_onecol(
+ async def get_aliases_for_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 2eeb9f97dc..82f9d870fd 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -223,15 +225,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return ret
- def count_e2e_room_keys(self, user_id, version):
+ async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
"""Get the number of keys in a backup version.
Args:
- user_id (str): the user whose backup we're querying
- version (str): the version ID of the backup we're querying about
+ user_id: the user whose backup we're querying
+ version: the version ID of the backup we're querying about
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",
@@ -368,18 +370,22 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- def update_e2e_room_keys_version(
- self, user_id, version, info=None, version_etag=None
- ):
+ async def update_e2e_room_keys_version(
+ self,
+ user_id: str,
+ version: str,
+ info: Optional[dict] = None,
+ version_etag: Optional[int] = None,
+ ) -> None:
"""Update a given backup version
Args:
- user_id(str): the user whose backup version we're updating
- version(str): the version ID of the backup version we're updating
- info (dict): the new backup version info to store. If None, then
- the backup version info is not updated
- version_etag (Optional[int]): etag of the keys in the backup. If
- None, then the etag is not updated
+ user_id: the user whose backup version we're updating
+ version: the version ID of the backup version we're updating
+ info: the new backup version info to store. If None, then the backup
+ version info is not updated.
+ version_etag: etag of the keys in the backup. If None, then the etag
+ is not updated.
"""
updatevalues = {}
@@ -389,7 +395,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
updatevalues["etag"] = version_etag
if updatevalues:
- return self.db_pool.simple_update(
+ await self.db_pool.simple_update(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
updatevalues=updatevalues,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..af0b85e2c9 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -27,6 +27,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
+if TYPE_CHECKING:
+ from synapse.handlers.e2e_keys import SignatureListItem
+
class EndToEndKeyWorkerStore(SQLBaseStore):
@trace
@@ -648,7 +651,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
"""Set a user's cross-signing key.
Args:
@@ -658,6 +661,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
key (dict): the key data
+            stream_id (int): the stream ID to store the key at
"""
# the 'key' dict will look something like:
# {
@@ -695,23 +699,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
# and finally, store the key itself
- with self._cross_signing_id_gen.get_next() as stream_id:
- self.db_pool.simple_insert_txn(
- txn,
- "e2e_cross_signing_keys",
- values={
- "user_id": user_id,
- "keytype": key_type,
- "keydata": json_encoder.encode(key),
- "stream_id": stream_id,
- },
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ "e2e_cross_signing_keys",
+ values={
+ "user_id": user_id,
+ "keytype": key_type,
+ "keydata": json_encoder.encode(key),
+ "stream_id": stream_id,
+ },
+ )
self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
- def set_e2e_cross_signing_key(self, user_id, key_type, key):
+ async def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
Args:
@@ -719,22 +722,27 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set
key (dict): the key data
"""
- return self.db_pool.runInteraction(
- "add_e2e_cross_signing_key",
- self._set_e2e_cross_signing_key_txn,
- user_id,
- key_type,
- key,
- )
- def store_e2e_cross_signing_signatures(self, user_id, signatures):
+ with await self._cross_signing_id_gen.get_next() as stream_id:
+ return await self.db_pool.runInteraction(
+ "add_e2e_cross_signing_key",
+ self._set_e2e_cross_signing_key_txn,
+ user_id,
+ key_type,
+ key,
+ stream_id,
+ )
+
+ async def store_e2e_cross_signing_signatures(
+ self, user_id: str, signatures: "Iterable[SignatureListItem]"
+ ) -> None:
"""Stores cross-signing signatures.
Args:
- user_id (str): the user who made the signatures
- signatures (iterable[SignatureListItem]): signatures to add
+ user_id: the user who made the signatures
+ signatures: signatures to add
"""
- return self.db_pool.simple_insert_many(
+ await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
[
{
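Aside: the end_to_end_keys change above follows a pattern repeated throughout this patch: the stream ID is allocated by the async caller via `with await self._cross_signing_id_gen.get_next() as stream_id:` and passed down into the transaction function, instead of being allocated inside the transaction. A minimal, self-contained sketch of that calling shape, with toy names rather than Synapse's real ID generator:

```python
import asyncio
from contextlib import contextmanager


class ToyIdGenerator:
    """Hypothetical stand-in for a stream ID generator with an async get_next()."""

    def __init__(self) -> None:
        self._current = 0

    async def get_next(self):
        # Returns a context manager, so callers write `with await gen.get_next() as id:`
        self._current += 1
        next_id = self._current

        @contextmanager
        def manager():
            yield next_id

        return manager()


def set_key_txn(store: list, key: str, stream_id: int) -> None:
    # The transaction body receives the allocated ID instead of allocating it.
    store.append((key, stream_id))


async def set_key(gen: ToyIdGenerator, store: list, key: str) -> None:
    # The async caller owns the ID allocation and hands it to the txn function.
    with await gen.get_next() as stream_id:
        set_key_txn(store, key, stream_id)


store: list = []
asyncio.run(set_key(ToyIdGenerator(), store, "master"))
print(store)  # [('master', 1)]
```

The transaction body stays synchronous; only the caller needs to be a coroutine.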
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 431bd76693..6e5761c7b7 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,14 +15,16 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Set, Tuple
from synapse.api.errors import StoreError
+from synapse.events import EventBase
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.types import Collection
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter
@@ -30,57 +32,51 @@ logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
- def get_auth_chain(self, event_ids, include_given=False):
+ async def get_auth_chain(
+ self, event_ids: Collection[str], include_given: bool = False
+ ) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
- event_ids (list): state events
- include_given (bool): include the given events in result
+ event_ids: state events
+ include_given: include the given events in result
Returns:
list of events
"""
- return self.get_auth_chain_ids(
+ event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
- ).addCallback(self.get_events_as_list)
-
- def get_auth_chain_ids(
- self,
- event_ids: List[str],
- include_given: bool = False,
- ignore_events: Optional[Set[str]] = None,
- ):
+ )
+ return await self.get_events_as_list(event_ids)
+
+ async def get_auth_chain_ids(
+ self, event_ids: Collection[str], include_given: bool = False,
+ ) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
event_ids: state events
include_given: include the given events in result
- ignore_events: Set of events to exclude from the returned auth
- chain. This is useful if the caller will just discard the
- given events anyway, and saves us from figuring out their auth
- chains if not required.
Returns:
list of event_ids
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_ids,
include_given,
- ignore_events,
)
- def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
- if ignore_events is None:
- ignore_events = set()
-
+ def _get_auth_chain_ids_txn(
+ self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
+ ) -> List[str]:
if include_given:
results = set(event_ids)
else:
results = set()
- base_sql = "SELECT auth_id FROM event_auth WHERE "
+ base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
front = set(event_ids)
while front:
@@ -92,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(base_sql + clause, args)
new_front.update(r[0] for r in txn)
- new_front -= ignore_events
new_front -= results
front = new_front
@@ -373,8 +368,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
@cached(max_entries=5000, iterable=True)
- def get_latest_event_ids_in_room(self, room_id):
- return self.db_pool.simple_select_onecol(
+ async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
@@ -459,7 +454,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
- def get_backfill_events(self, room_id, event_list, limit):
+ async def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
@@ -469,17 +464,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_list (list)
limit (int)
"""
- return (
- self.db_pool.runInteraction(
- "get_backfill_events",
- self._get_backfill_events,
- room_id,
- event_list,
- limit,
- )
- .addCallback(self.get_events_as_list)
- .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
+ event_ids = await self.db_pool.runInteraction(
+ "get_backfill_events",
+ self._get_backfill_events,
+ room_id,
+ event_list,
+ limit,
)
+ events = await self.get_events_as_list(event_ids)
+ return sorted(events, key=lambda e: -e.depth)
def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -540,8 +533,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
- events = await self.get_events_as_list(ids)
- return events
+ return await self.get_events_as_list(ids)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
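For reference, `_get_auth_chain_ids_txn` above is a breadth-first walk over the `event_auth` table: it expands a frontier of event IDs until no new auth events turn up (the `ignore_events` parameter is dropped, and the query now uses `SELECT DISTINCT`). A rough sketch of the same traversal against an in-memory auth map, with hypothetical data in place of the database:

```python
from typing import Dict, List, Set


def get_auth_chain_ids(
    auth_map: Dict[str, Set[str]], event_ids: List[str], include_given: bool = False
) -> List[str]:
    results: Set[str] = set(event_ids) if include_given else set()
    front = set(event_ids)
    while front:
        new_front: Set[str] = set()
        for event_id in front:
            # Stand-in for: SELECT DISTINCT auth_id FROM event_auth WHERE event_id = ?
            new_front.update(auth_map.get(event_id, set()))
        new_front -= results  # don't revisit events already collected
        front = new_front
        results.update(front)
    return list(results)


auth_map = {"$c": {"$b"}, "$b": {"$a"}, "$a": set()}
print(sorted(get_auth_chain_ids(auth_map, ["$c"])))  # ['$a', '$b']
```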
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b90e6de2d5..6313b41eef 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -153,11 +153,11 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = await self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8c63a0dc4d..e6247d682d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -19,9 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload
from constantly import NamedConstant, Names
+from typing_extensions import Literal
from twisted.internet import defer
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersions,
)
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -112,33 +113,58 @@ class EventsWorkerStore(SQLBaseStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
- self._stream_id_gen.advance(token)
+ self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
- self._backfill_id_gen.advance(-token)
+ self._backfill_id_gen.advance(instance_name, -token)
super().process_replication_rows(stream_name, instance_name, token, rows)
- def get_received_ts(self, event_id):
+ async def get_received_ts(self, event_id: str) -> Optional[int]:
"""Get received_ts (when it was persisted) for the event.
Raises an exception for unknown events.
Args:
- event_id (str)
+ event_id: The event ID to query.
Returns:
- Deferred[int|None]: Timestamp in milliseconds, or None for events
- that were persisted before received_ts was implemented.
+ Timestamp in milliseconds, or None for events that were persisted
+ before received_ts was implemented.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
desc="get_received_ts",
)
- @defer.inlineCallbacks
- def get_event(
+ # Inform mypy that if allow_none is False (the default) then get_event
+ # always returns an EventBase.
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[False] = False,
+ check_room_id: Optional[str] = None,
+ ) -> EventBase:
+ ...
+
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[True] = False,
+ check_room_id: Optional[str] = None,
+ ) -> Optional[EventBase]:
+ ...
+
+ async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
- ):
+ ) -> Optional[EventBase]:
"""Get an event from the database by event_id.
Args:
@@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
- Deferred[EventBase|None]
+ The event, or None if the event was not found.
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[event_id],
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
return event
- @defer.inlineCallbacks
- def get_events(
+ async def get_events(
self,
- event_ids: List[str],
+ event_ids: Iterable[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> Dict[str, EventBase]:
"""Get events from the database
Args:
@@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
- Deferred : Dict from event_id to event.
+ A mapping from event_id to event.
"""
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
event_ids,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
- @defer.inlineCallbacks
- def get_events_as_list(
+ async def get_events_as_list(
self,
- event_ids: List[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> List[EventBase]:
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
@@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
- Deferred[list[EventBase]]: List of events fetched from the database. The
- events are in the same order as `event_ids` arg.
+ List of events fetched from the database. The events are in the same
+ order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
@@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
- event_entry_map = yield self._get_events_from_cache_or_db(
+ event_entry_map = await self._get_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
- event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+ event_map = await self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
if get_prev_content:
if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
+ prev = await self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
@@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
return events
- @defer.inlineCallbacks
- def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
@@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
- missing_events = yield self._get_events_from_db(
+ missing_events = await self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
@@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
- @defer.inlineCallbacks
- def _get_events_from_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the database.
Returned events will be added to the cache for future lookups.
@@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result. May return extra events which
weren't asked for.
"""
@@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
events_to_fetch = event_ids
while events_to_fetch:
- row_map = yield self._enqueue_events(events_to_fetch)
+ row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
@@ -574,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason:
continue
- d = db_to_json(row["json"])
- internal_metadata = db_to_json(row["internal_metadata"])
+ # If the event or metadata cannot be parsed, log the error and act
+ # as if the event is unknown.
+ try:
+ d = db_to_json(row["json"])
+ except ValueError:
+ logger.error("Unable to parse json from event: %s", event_id)
+ continue
+ try:
+ internal_metadata = db_to_json(row["internal_metadata"])
+ except ValueError:
+ logger.error(
+ "Unable to parse internal_metadata from event: %s", event_id
+ )
+ continue
format_version = row["format_version"]
if format_version is None:
@@ -586,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row["room_version_id"]
if not room_version_id:
- # this should only happen for out-of-band membership events
- if not internal_metadata.get("out_of_band_membership"):
- logger.warning(
- "Room %s for event %s is unknown", d["room_id"], event_id
+ # this should only happen for out-of-band membership events which
+ # arrived before #6983 landed. For all other events, we should have
+ # an entry in the 'rooms' table.
+ #
+ # However, the 'out_of_band_membership' flag is unreliable for older
+ # invites, so just accept it for all membership events.
+ #
+ if d["type"] != EventTypes.Member:
+ raise Exception(
+ "Room %s for event %s is unknown" % (d["room_id"], event_id)
)
- continue
- # take a wild stab at the room version based on the event format
+ # so, assuming this is an out-of-band-invite that arrived before #6983
+ # landed, we know that the room version must be v5 or earlier (because
+ # v6 hadn't been invented at that point, so invites from such rooms
+ # would have been rejected.)
+ #
+ # The main reason we need to know the room version here (other than
+ # choosing the right python Event class) is in case the event later has
+ # to be redacted - and all the room versions up to v5 used the same
+ # redaction algorithm.
+ #
+ # So, the following approximations should be adequate.
+
if format_version == EventFormatVersions.V1:
+ # if it's event format v1 then it must be room v1 or v2
room_version = RoomVersions.V1
elif format_version == EventFormatVersions.V2:
+ # if it's event format v2 then it must be room v3
room_version = RoomVersions.V3
else:
+ # if it's event format v3 then it must be room v4 or v5
room_version = RoomVersions.V5
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
@@ -650,8 +703,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- @defer.inlineCallbacks
- def _enqueue_events(self, events):
+ async def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -660,7 +712,7 @@ class EventsWorkerStore(SQLBaseStore):
events (Iterable[str]): events to be fetched.
Returns:
- Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+ Dict[str, Dict]: map from event id to row data from the database.
May contain events that weren't requested.
"""
@@ -683,7 +735,7 @@ class EventsWorkerStore(SQLBaseStore):
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
- row_map = yield events_d
+ row_map = await events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
return row_map
@@ -842,33 +894,29 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- @defer.inlineCallbacks
- def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = yield defer.ensureDeferred(
- self.db_pool.simple_select_many_batch(
- table="events",
- retcols=("event_id",),
- column="event_id",
- iterable=list(event_ids),
- keyvalues={"outlier": False},
- desc="have_events_in_timeline",
- )
+ rows = await self.db_pool.simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
)
return {r["event_id"] for r in rows}
- @defer.inlineCallbacks
- def have_seen_events(self, event_ids):
+ async def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
- Deferred[set[str]]: The events we have already seen.
+ set[str]: The events we have already seen.
"""
results = set()
@@ -884,7 +932,7 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
@@ -914,8 +962,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- @defer.inlineCallbacks
- def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -926,9 +973,9 @@ class EventsWorkerStore(SQLBaseStore):
room_id (str)
Returns:
- Deferred[dict[str:int]] of complexity version to complexity.
+ dict[str, int] of complexity version to complexity.
"""
- state_events = yield self.get_current_state_event_counts(room_id)
+ state_events = await self.get_current_state_event_counts(room_id)
# Call this one "v1", so we can introduce new ones as we want to develop
# it.
@@ -1165,9 +1212,9 @@ class EventsWorkerStore(SQLBaseStore):
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
- @cachedInlineCallbacks(max_entries=5000)
- def get_event_ordering(self, event_id):
- res = yield self.db_pool.simple_select_one(
+ @cached(max_entries=5000)
+ async def get_event_ordering(self, event_id):
+ res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 380db3a3f3..e3ead71853 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
class GroupServerWorkerStore(SQLBaseStore):
- def get_group(self, group_id):
- return self.db_pool.simple_select_one(
+ async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -44,24 +44,26 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group",
)
- def get_users_in_group(self, group_id, include_private=False):
+ async def get_users_in_group(
+ self, group_id: str, include_private: bool = False
+ ) -> List[Dict[str, Any]]:
# TODO: Pagination
keyvalues = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
- return self.db_pool.simple_select_list(
+ return await self.db_pool.simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"),
desc="get_users_in_group",
)
- def get_invited_users_in_group(self, group_id):
+ async def get_invited_users_in_group(self, group_id: str) -> List[str]:
# TODO: Pagination
- return self.db_pool.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="group_invites",
keyvalues={"group_id": group_id},
retcol="user_id",
@@ -265,15 +267,14 @@ class GroupServerWorkerStore(SQLBaseStore):
return role
- def get_local_groups_for_room(self, room_id):
+ async def get_local_groups_for_room(self, room_id: str) -> List[str]:
"""Get all of the local group that contain a given room
Args:
- room_id (str): The ID of a room
+ room_id: The ID of a room
Returns:
- Deferred[list[str]]: A twisted.Deferred containing a list of group ids
- containing this room
+ A list of group ids containing this room
"""
- return self.db_pool.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="group_rooms",
keyvalues={"room_id": room_id},
retcol="group_id",
@@ -341,17 +342,20 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_users_for_summary_by_role", _get_users_for_summary_txn
)
- def is_user_in_group(self, user_id, group_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
+ result = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
allow_none=True,
desc="is_user_in_group",
- ).addCallback(lambda r: bool(r))
+ )
+ return bool(result)
- def is_user_admin_in_group(self, group_id, user_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_user_admin_in_group(
+ self, group_id: str, user_id: str
+ ) -> Optional[bool]:
+ return await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@@ -359,10 +363,12 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="is_user_admin_in_group",
)
- def is_user_invited_to_local_group(self, group_id, user_id):
+ async def is_user_invited_to_local_group(
+ self, group_id: str, user_id: str
+ ) -> Optional[bool]:
"""Has the group server invited a user?
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@@ -417,10 +423,10 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
)
- def get_publicised_groups_for_user(self, user_id):
+ async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
"""Get all groups a user is publicising
"""
- return self.db_pool.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id",
@@ -461,8 +467,8 @@ class GroupServerWorkerStore(SQLBaseStore):
return None
- def get_joined_groups(self, user_id):
- return self.db_pool.simple_select_onecol(
+ async def get_joined_groups(self, user_id: str) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id",
@@ -580,14 +586,14 @@ class GroupServerWorkerStore(SQLBaseStore):
class GroupServerStore(GroupServerWorkerStore):
- def set_group_join_policy(self, group_id, join_policy):
+ async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
"""Set the join policy of a group.
join_policy can be one of:
* "invite"
* "open"
"""
- return self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues={"join_policy": join_policy},
@@ -1045,8 +1051,10 @@ class GroupServerStore(GroupServerWorkerStore):
desc="add_room_to_group",
)
- def update_room_in_group_visibility(self, group_id, room_id, is_public):
- return self.db_pool.simple_update(
+ async def update_room_in_group_visibility(
+ self, group_id: str, room_id: str, is_public: bool
+ ) -> int:
+ return await self.db_pool.simple_update(
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public},
@@ -1071,10 +1079,12 @@ class GroupServerStore(GroupServerWorkerStore):
"remove_room_from_group", _remove_room_from_group_txn
)
- def update_group_publicity(self, group_id, user_id, publicise):
+ async def update_group_publicity(
+ self, group_id: str, user_id: str, publicise: bool
+ ) -> None:
"""Update whether the user is publicising their membership of the group
"""
- return self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise},
@@ -1181,7 +1191,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with self._group_updates_id_gen.get_next() as next_id:
+ with await self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
@@ -1213,20 +1223,24 @@ class GroupServerStore(GroupServerWorkerStore):
desc="update_group_profile",
)
- def update_attestation_renewal(self, group_id, user_id, attestation):
+ async def update_attestation_renewal(
+ self, group_id: str, user_id: str, attestation: dict
+ ) -> None:
"""Update an attestation that we have renewed
"""
- return self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
desc="update_attestation_renewal",
)
- def update_remote_attestion(self, group_id, user_id, attestation):
+ async def update_remote_attestion(
+ self, group_id: str, user_id: str, attestation: dict
+ ) -> None:
"""Update an attestation that a remote has renewed
"""
- return self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 384e9c5eb0..fadcad51e7 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@
import itertools
import logging
+from typing import Iterable, Tuple
from signedjson.key import decode_verify_key_bytes
@@ -88,12 +89,17 @@ class KeyStore(SQLBaseStore):
return self.db_pool.runInteraction("get_server_verify_keys", _txn)
- def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+ async def store_server_verify_keys(
+ self,
+ from_server: str,
+ ts_added_ms: int,
+ verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+ ) -> None:
"""Stores NACL verification keys for remote servers.
Args:
- from_server (str): Where the verification keys were looked up
- ts_added_ms (int): The time to record that the key was added
- verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+ from_server: Where the verification keys were looked up
+ ts_added_ms: The time to record that the key was added
+ verify_keys:
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
@@ -115,13 +121,7 @@ class KeyStore(SQLBaseStore):
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
- def _invalidate(res):
- f = self._get_server_verify_key.invalidate
- for i in invalidations:
- f((i,))
- return res
-
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"store_server_verify_keys",
self.db_pool.simple_upsert_many_txn,
table="server_signature_keys",
@@ -134,7 +134,11 @@ class KeyStore(SQLBaseStore):
"verify_key",
),
value_values=value_values,
- ).addCallback(_invalidate)
+ )
+
+ invalidate = self._get_server_verify_key.invalidate
+ for i in invalidations:
+ invalidate((i,))
def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 80fc1cd009..8361dd63d9 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, List, Optional
+
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
- def get_local_media(self, media_id):
+ async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
+
Returns:
None if the media_id doesn't exist.
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -81,9 +84,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_media",
)
- def mark_local_media_as_safe(self, media_id: str):
+ async def mark_local_media_as_safe(self, media_id: str) -> None:
"""Mark a local media as safe from quarantining."""
- return self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="local_media_repository",
keyvalues={"media_id": media_id},
updatevalues={"safe_from_quarantine": True},
@@ -155,8 +158,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_url_cache",
)
- def get_local_media_thumbnails(self, media_id):
- return self.db_pool.simple_select_list(
+ async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
+ return await self.db_pool.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
@@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_thumbnail",
)
- def get_cached_remote_media(self, origin, media_id):
- return self.db_pool.simple_select_one(
+ async def get_cached_remote_media(
+ self, origin, media_id: str
+ ) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -266,8 +271,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"update_cached_last_access_time", update_cache_txn
)
- def get_remote_media_thumbnails(self, origin, media_id):
- return self.db_pool.simple_select_list(
+ async def get_remote_media_thumbnails(
+ self, origin: str, media_id: str
+ ) -> List[Dict[str, Any]]:
+ return await self.db_pool.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
@@ -307,14 +314,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_remote_media_thumbnail",
)
- def get_remote_media_before(self, before_ts):
+ async def get_remote_media_before(self, before_ts):
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)
- return self.db_pool.execute(
+ return await self.db_pool.execute(
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
)
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 1d4db758d4..66953ffc26 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return users
@cached(num_args=1)
- def user_last_seen_monthly_active(self, user_id):
+ async def user_last_seen_monthly_active(self, user_id: str) -> int:
"""
- Checks if a given user is part of the monthly active user group
- Arguments:
- user_id (str): user to add/update
- Return:
- Deferred[int] : timestamp since last seen, None if never seen
+ Checks if a given user is part of the monthly active user group
+ Arguments:
+ user_id: user to check
+
+ Return:
+ Timestamp since last seen, None if never seen
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4e3ec02d14..c9f655dfb7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
- stream_ordering_manager = self._presence_id_gen.get_next_mult(
+ stream_ordering_manager = await self._presence_id_gen.get_next_mult(
len(presence_states)
)
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index b8261357d4..858fd92420 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
@@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore):
- async def get_profileinfo(self, user_localpart):
+ async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
@@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
- def get_profile_displayname(self, user_localpart):
- return self.db_pool.simple_select_one_onecol(
+ async def get_profile_displayname(self, user_localpart: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
- def get_profile_avatar_url(self, user_localpart):
- return self.db_pool.simple_select_one_onecol(
+ async def get_profile_avatar_url(self, user_localpart: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
- def get_from_remote_profile_cache(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_from_remote_profile_cache(
+ self, user_id: str
+ ) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
@@ -68,16 +71,20 @@ class ProfileWorkerStore(SQLBaseStore):
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self.db_pool.simple_update_one(
+ async def set_profile_displayname(
+ self, user_localpart: str, new_displayname: str
+ ) -> None:
+ await self.db_pool.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
desc="set_profile_displayname",
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self.db_pool.simple_update_one(
+ async def set_profile_avatar_url(
+ self, user_localpart: str, new_avatar_url: str
+ ) -> None:
+ await self.db_pool.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
@@ -103,8 +110,10 @@ class ProfileStore(ProfileWorkerStore):
desc="add_remote_profile_cache",
)
- def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self.db_pool.simple_update(
+ async def update_remote_profile_cache(
+ self, user_id: str, displayname: str, avatar_url: str
+ ) -> int:
+ return await self.db_pool.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
updatevalues={
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index c2289a9557..2fb5b02d7d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -30,7 +30,7 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +82,9 @@ class PushRulesWorkerStore(
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen = ChainedIdGenerator(
- self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
- ) # type: Union[ChainedIdGenerator, SlavedIdTracker]
+ self._push_rules_stream_id_gen = StreamIdGenerator(
+ db_conn, "push_rules_stream", "stream_id"
+ ) # type: Union[StreamIdGenerator, SlavedIdTracker]
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
@@ -338,8 +338,9 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
if before or after:
await self.db_pool.runInteraction(
"_add_push_rule_relative_txn",
@@ -559,8 +560,9 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
await self.db_pool.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
@@ -569,8 +571,9 @@ class PushRuleStore(PushRulesWorkerStore):
)
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
@@ -643,8 +646,9 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
await self.db_pool.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
@@ -673,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
- def get_push_rules_stream_token(self):
- """Get the position of the push rules stream.
- Returns a pair of a stream id for the push_rules stream and the
- room stream ordering it corresponds to."""
- return self._push_rules_stream_id_gen.get_current_token()
-
def get_max_push_rules_stream_id(self):
- return self.get_push_rules_stream_token()[0]
+ return self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 1126fd0751..c388468273 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering,
profile_tag="",
) -> None:
- with self._pushers_id_gen.get_next() as stream_id:
+ with await self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with self._pushers_id_gen.get_next() as stream_id:
+ with await self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 19ad1c056f..436f22ad2d 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -16,7 +16,7 @@
import abc
import logging
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from twisted.internet import defer
@@ -62,8 +62,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {r["user_id"] for r in receipts}
@cached(num_args=2)
- def get_receipts_for_room(self, room_id, receipt_type):
- return self.db_pool.simple_select_list(
+ async def get_receipts_for_room(
+ self, room_id: str, receipt_type: str
+ ) -> List[Dict[str, Any]]:
+ return await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
@@ -71,8 +73,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
@cached(num_args=3)
- def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
- return self.db_pool.simple_select_one_onecol(
+ async def get_last_receipt_event_id_for_user(
+ self, user_id: str, room_id: str, receipt_type: str
+ ) -> Optional[str]:
+ return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@@ -520,8 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear
)
- stream_id_manager = self._receipts_id_gen.get_next()
- with stream_id_manager as stream_id:
+ with await self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 068ad22b30..48bda66f3e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
import logging
import re
-from typing import Awaitable, Dict, List, Optional
+from typing import Any, Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cached()
- def get_user_by_id(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -578,20 +578,20 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="add_user_bound_threepid",
)
- def user_get_bound_threepids(self, user_id):
+ async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
"""Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids.
Args:
- user_id (str): The ID of the user to retrieve threepids for
+ user_id: The ID of the user to retrieve threepids for
Returns:
- Deferred[list[dict]]: List of dictionaries containing the following:
+ List of dictionaries containing the following keys:
medium (str): The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com")
"""
- return self.db_pool.simple_select_list(
+ return await self.db_pool.simple_select_list(
table="user_threepid_id_server",
keyvalues={"user_id": user_id},
retcols=["medium", "address"],
@@ -623,19 +623,21 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="remove_user_bound_threepid",
)
- def get_id_servers_user_bound(self, user_id, medium, address):
+ async def get_id_servers_user_bound(
+ self, user_id: str, medium: str, address: str
+ ) -> List[str]:
"""Get the list of identity servers that the server proxied bind
requests to for given user and threepid
Args:
- user_id (str)
- medium (str)
- address (str)
+ user_id: The user to query for identity servers.
+ medium: The medium to query for identity servers.
+ address: The address to query for identity servers.
Returns:
- Deferred[list[str]]: Resolves to a list of identity servers
+ A list of identity servers
"""
- return self.db_pool.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="user_threepid_id_server",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
@@ -889,6 +891,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
super(RegistrationStore, self).__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
+ self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
if self._account_validity.enabled:
self._clock.call_later(
@@ -1258,12 +1261,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="del_user_pending_deactivation",
)
- def get_user_pending_deactivation(self):
+ async def get_user_pending_deactivation(self) -> Optional[str]:
"""
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@@ -1302,15 +1305,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
if not row:
- raise ThreepidValidationError(400, "Unknown session_id")
+ if self._ignore_unknown_session_error:
+ # If we need to inhibit the error caused by an incorrect session ID,
+ # use None as placeholder values for the client secret and the
+ # validation timestamp.
+ # It shouldn't be an issue because they're both only checked after
+ # the token check, which should fail. And if it doesn't for some
+ # reason, the next check is on the client secret, which is NOT NULL,
+ # so we don't have to worry about the client secret matching by
+ # accident.
+ row = {"client_secret": None, "validated_at": None}
+ else:
+ raise ThreepidValidationError(400, "Unknown session_id")
+
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
- if retrieved_client_secret != client_secret:
- raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id"
- )
-
row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_token",
@@ -1326,6 +1336,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expires = row["expires"]
next_link = row["next_link"]
+ if retrieved_client_secret != client_secret:
+ raise ThreepidValidationError(
+ 400, "This client_secret does not match the provided session_id"
+ )
+
# If the session is already validated, no need to revalidate
if validated_at:
return next_link
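The registration change above substitutes placeholder values when `request_token_inhibit_3pid_errors` is set, so an unknown session ID fails at the later token check instead of producing a distinguishable error. A toy sketch of just that branch (the real code then looks up the validation token and only afterwards compares the client secret):

```python
class ThreepidValidationError(Exception):
    """Toy stand-in for Synapse's ThreepidValidationError."""

    def __init__(self, code: int, msg: str) -> None:
        super().__init__(msg)
        self.code = code


def get_validation_session(
    sessions: dict, session_id: str, inhibit_unknown_session_error: bool
) -> dict:
    row = sessions.get(session_id)
    if row:
        return row
    if inhibit_unknown_session_error:
        # Placeholder values: the later token check (and then the client_secret
        # check) will still fail, so nothing leaks about whether the session exists.
        return {"client_secret": None, "validated_at": None}
    raise ThreepidValidationError(400, "Unknown session_id")


print(get_validation_session({}, "some_session", inhibit_unknown_session_error=True))
```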
diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py
index cf9ba51205..1e361aaa9a 100644
--- a/synapse/storage/databases/main/rejections.py
+++ b/synapse/storage/databases/main/rejections.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Optional
from synapse.storage._base import SQLBaseStore
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
- def get_rejection_reason(self, event_id):
- return self.db_pool.simple_select_one_onecol(
+ async def get_rejection_reason(self, event_id: str) -> Optional[str]:
+ return await self.db_pool.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index aef08c7e12..66d7135413 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -21,8 +21,6 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from canonicaljson import json
-
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
@@ -30,6 +28,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchStore
from synapse.types import ThirdPartyInstanceID
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -74,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
self.config = hs.config
- def get_room(self, room_id):
+ async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve a room.
Args:
- room_id (str): The ID of the room to retrieve.
+ room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@@ -126,8 +125,8 @@ class RoomWorkerStore(SQLBaseStore):
"get_room_with_stats", get_room_with_stats_txn, room_id
)
- def get_public_room_ids(self):
- return self.db_pool.simple_select_onecol(
+ async def get_public_room_ids(self) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
table="rooms",
keyvalues={"is_public": True},
retcol="room_id",
@@ -331,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
return ret_val
@cached(max_entries=10000)
- def is_room_blocked(self, room_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_room_blocked(self, room_id: str) -> Optional[bool]:
+ return await self.db_pool.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",
@@ -1130,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1197,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1277,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
@@ -1310,7 +1309,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"event_id": event_id,
"user_id": user_id,
"reason": reason,
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
},
desc="add_event_report",
)
diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
new file mode 100644
index 0000000000..4cc96a5341
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A table of the IP address and user-agent used to complete each step of a
+-- user-interactive authentication session.
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
+ session_id TEXT NOT NULL,
+ ip TEXT NOT NULL,
+ user_agent TEXT NOT NULL,
+ UNIQUE (session_id, ip, user_agent),
+ FOREIGN KEY (session_id)
+ REFERENCES ui_auth_sessions (session_id)
+);
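To illustrate how the new `ui_auth_sessions_ips` table behaves, here is a quick, hypothetical exercise of the schema against an in-memory SQLite database; the foreign key to `ui_auth_sessions` is omitted for brevity and the row values are made up:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
        session_id TEXT NOT NULL,
        ip TEXT NOT NULL,
        user_agent TEXT NOT NULL,
        UNIQUE (session_id, ip, user_agent)
    )
    """
)
# Record the IP and user-agent used for a step; the UNIQUE constraint collapses
# repeated use of the same combination within a session.
conn.execute(
    "INSERT OR IGNORE INTO ui_auth_sessions_ips VALUES (?, ?, ?)",
    ("some_session", "203.0.113.5", "curl/7.68.0"),
)
print(conn.execute("SELECT * FROM ui_auth_sessions_ips").fetchall())
```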
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 991233a9bc..458f169617 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return event.content.get("canonical_alias")
@cached(max_entries=50000)
- def _get_state_group_for_event(self, event_id):
- return self.db_pool.simple_select_one_onecol(
+ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
+ return await self.db_pool.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 802c9019b9..9fe97af56a 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):
return len(rooms_to_work_on)
- def get_stats_positions(self):
+ async def get_stats_positions(self) -> int:
"""
Returns the stats processor positions.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
@@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
- def get_earliest_token_for_stats(self, stats_type, id):
+ async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
being calculated.
Returns:
- Deferred[int]
+ The earliest token.
"""
table, id_col = TYPE_TO_TABLE[stats_type]
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 4377bddb8c..497f607703 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
limit: int = 0,
order: str = "DESC",
) -> Tuple[List[EventBase], str]:
-
"""Get new room events in stream ordering since `from_key`.
Args:
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e4e0a0c433..0c34bbf21a 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,11 +17,10 @@
import logging
from typing import Dict, List, Tuple
-from canonicaljson import json
-
from synapse.storage._base import db_to_json
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.types import JsonDict
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (user_id, room_id))
tags = []
for tag, content in txn:
- tags.append(json.dumps(tag) + ":" + content)
+ tags.append(json_encoder.encode(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, (user_id, room_id, tag_json)))
@@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
self.db_pool.simple_upsert_txn(
@@ -211,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -233,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
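Note: the tags loop above splices pre-encoded JSON strings together rather than building a dict and encoding it once. A hedged illustration of the resulting shape, using the stdlib `json.dumps` as a stand-in for `json_encoder.encode` and made-up rows:

    import json

    # Made-up (tag, content) rows; `content` is already a JSON string from the
    # database, so it is spliced in verbatim rather than re-encoded.
    rows = [("m.favourite", '{"order":0.1}'), ("u.work", '{}')]

    tags = []
    for tag, content in rows:
        tags.append(json.dumps(tag) + ":" + content)

    tag_json = "{" + ",".join(tags) + "}"
    print(tag_json)  # {"m.favourite":{"order":0.1},"u.work":{}}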
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 37276f73f8..9eef8e57c5 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,15 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import attr
-from canonicaljson import json
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
-from synapse.util import stringutils as stringutils
+from synapse.util import json_encoder, stringutils
@attr.s
@@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore):
StoreError if a unique session ID cannot be generated.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
# autogen a session ID and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -143,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore):
await self.db_pool.simple_upsert(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id, "stage_type": stage_type},
- values={"result": json.dumps(result)},
+ values={"result": json_encoder.encode(result)},
desc="mark_ui_auth_stage_complete",
)
except self.db_pool.engine.module.IntegrityError:
@@ -184,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore):
The dictionary from the client root level, not the 'auth' key.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
await self.db_pool.simple_update_one(
table="ui_auth_sessions",
@@ -214,14 +214,16 @@ class UIAuthWorkerStore(SQLBaseStore):
value,
)
- def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+ def _set_ui_auth_session_data_txn(
+ self, txn: LoggingTransaction, session_id: str, key: str, value: Any
+ ):
# Get the current value.
result = self.db_pool.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
- )
+ ) # type: Dict[str, Any] # type: ignore
# Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"])
@@ -231,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
- updatevalues={"serverdict": json.dumps(serverdict)},
+ updatevalues={"serverdict": json_encoder.encode(serverdict)},
)
async def get_ui_auth_session_data(
@@ -258,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore):
return serverdict.get(key, default)
+ async def add_user_agent_ip_to_ui_auth_session(
+ self, session_id: str, user_agent: str, ip: str,
+ ):
+ """Add the given user agent / IP to the tracking table
+ """
+ await self.db_pool.simple_upsert(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
+ values={},
+ desc="add_user_agent_ip_to_ui_auth_session",
+ )
+
+ async def get_user_agents_ips_to_ui_auth_session(
+ self, session_id: str,
+ ) -> List[Tuple[str, str]]:
+ """Get the given user agents / IPs used during the ui auth process
+
+ Returns:
+ List of user_agent/ip pairs
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ )
+ return [(row["user_agent"], row["ip"]) for row in rows]
+
class UIAuthStore(UIAuthWorkerStore):
def delete_old_ui_auth_sessions(self, expiration_time: int):
@@ -275,12 +305,23 @@ class UIAuthStore(UIAuthWorkerStore):
expiration_time,
)
- def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+ def _delete_old_ui_auth_sessions_txn(
+ self, txn: LoggingTransaction, expiration_time: int
+ ):
# Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time])
session_ids = [r[0] for r in txn.fetchall()]
+ # Delete the corresponding IP/user agents.
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions_ips",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
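Note: a hedged sketch of how the two new UI-auth tracking methods above might be exercised by a caller (the `store` object and `session_id` are assumed to come from the surrounding handler code):

    async def record_and_report(store, session_id: str) -> None:
        # Record the client details for one step of the UI-auth flow.  Because
        # the table has a UNIQUE (session_id, ip, user_agent) constraint and the
        # upsert writes no other values, repeated calls are de-duplicated.
        await store.add_user_agent_ip_to_ui_auth_session(
            session_id, user_agent="Mozilla/5.0 (example)", ip="203.0.113.7",
        )

        # Later: fetch every distinct user_agent/ip pair seen for the session.
        for user_agent, ip in await store.get_user_agents_ips_to_ui_auth_session(
            session_id
        ):
            print(user_agent, ip)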
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index af21fe457a..a9f2e93614 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,6 +15,7 @@
import logging
import re
+from typing import Any, Dict, Optional
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
- def get_user_in_directory(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@@ -536,8 +537,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
- def update_user_directory_stream_pos(self, stream_id):
- return self.db_pool.simple_update_one(
+ async def update_user_directory_stream_pos(self, stream_id: int) -> None:
+ await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
@@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- def get_user_directory_stream_pos(self):
- return self.db_pool.simple_select_one_onecol(
+ async def get_user_directory_stream_pos(self) -> int:
+ return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index da23fe7355..e3547e53b3 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -13,30 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import operator
-
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
class UserErasureWorkerStore(SQLBaseStore):
@cached()
- def is_user_erased(self, user_id):
+ async def is_user_erased(self, user_id: str) -> bool:
"""
Check if the given user id has requested erasure
Args:
- user_id (str): full user id to check
+ user_id: full user id to check
Returns:
- Deferred[bool]: True if the user has requested erasure
+ True if the user has requested erasure
"""
- return self.db_pool.simple_select_onecol(
+ result = await self.db_pool.simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
desc="is_user_erased",
- ).addCallback(operator.truth)
+ )
+ return bool(result)
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
async def are_users_erased(self, user_ids):
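Note: dropping `.addCallback(operator.truth)` in favour of `bool(result)` above works because `simple_select_onecol` returns the matching column values as a list, so truthiness answers "was any row found?". A tiny illustration with made-up results:

    # Made-up return values from simple_select_onecol for two different users.
    erased_rows = ["1"]       # a row exists in erased_users for this user
    not_erased_rows = []      # no row exists

    print(bool(erased_rows))      # True
    print(bool(not_erased_rows))  # False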
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e2ddd01290..b27a4843d0 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -14,9 +14,10 @@
# limitations under the License.
import contextlib
+import heapq
import threading
from collections import deque
-from typing import Dict, Set, Tuple
+from typing import Dict, List, Set
from typing_extensions import Deque
@@ -80,7 +81,7 @@ class StreamIdGenerator(object):
upwards, -1 to grow downwards.
Usage:
- with stream_id_gen.get_next() as stream_id:
+ with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -95,10 +96,10 @@ class StreamIdGenerator(object):
)
self._unfinished_ids = deque() # type: Deque[int]
- def get_next(self):
+ async def get_next(self):
"""
Usage:
- with stream_id_gen.get_next() as stream_id:
+ with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -117,10 +118,10 @@ class StreamIdGenerator(object):
return manager()
- def get_next_mult(self, n):
+ async def get_next_mult(self, n):
"""
Usage:
- with stream_id_gen.get_next(n) as stream_ids:
+ with await stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -158,63 +159,13 @@ class StreamIdGenerator(object):
return self._current
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
-class ChainedIdGenerator(object):
- """Used to generate new stream ids where the stream must be kept in sync
- with another stream. It generates pairs of IDs, the first element is an
- integer ID for this stream, the second element is the ID for the stream
- that this stream needs to be kept in sync with."""
-
- def __init__(self, chained_generator, db_conn, table, column):
- self.chained_generator = chained_generator
- self._table = table
- self._lock = threading.Lock()
- self._current_max = _load_current_id(db_conn, table, column)
- self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
-
- def get_next(self):
+ For streams with single writers this is equivalent to
+ `get_current_token`.
"""
- Usage:
- with stream_id_gen.get_next() as (stream_id, chained_id):
- # ... persist event ...
- """
- with self._lock:
- self._current_max += 1
- next_id = self._current_max
- chained_id = self.chained_generator.get_current_token()
-
- self._unfinished_ids.append((next_id, chained_id))
-
- @contextlib.contextmanager
- def manager():
- try:
- yield (next_id, chained_id)
- finally:
- with self._lock:
- self._unfinished_ids.remove((next_id, chained_id))
-
- return manager()
-
- def get_current_token(self):
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
- """
- with self._lock:
- if self._unfinished_ids:
- stream_id, chained_id = self._unfinished_ids[0]
- return stream_id - 1, chained_id
-
- return self._current_max, self.chained_generator.get_current_token()
-
- def advance(self, token: int):
- """Stub implementation for advancing the token when receiving updates
- over replication; raises an exception as this instance should be the
- only source of updates.
- """
-
- raise Exception(
- "Attempted to advance token on source for table %r", self._table
- )
+ return self.get_current_token()
class MultiWriterIdGenerator:
@@ -260,6 +211,23 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ # We track the max position where we know everything before has been
+ # persisted. This is done by a) looking at the min across all instances
+ # and b) noting that if we have seen a run of persisted positions
+ # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
+ #
+ # Note: There is no guarantee that the IDs generated by the sequence
+ # will be gapless; gaps can form when e.g. a transaction was rolled
+ # back. This means that sometimes we won't be able to skip forward the
+ # position even though everything has been persisted. However, since
+ # gaps should be relatively rare, it's still worth doing the bookkeeping
+ # that allows us to skip forwards when there are gapless runs of
+ # positions.
+ self._persisted_upto_position = (
+ min(self._current_positions.values()) if self._current_positions else 0
+ )
+ self._known_persisted_positions = [] # type: List[int]
+
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
def _load_current_ids(
@@ -284,9 +252,12 @@ class MultiWriterIdGenerator:
return current_positions
- def _load_next_id_txn(self, txn):
+ def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn)
+ def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+ return self._sequence_gen.get_next_mult_txn(txn, n)
+
async def get_next(self):
"""
Usage:
@@ -298,7 +269,7 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
- assert self.get_current_token() < next_id
+ assert self.get_current_token_for_writer(self._instance_name) < next_id
with self._lock:
self._unfinished_ids.add(next_id)
@@ -312,6 +283,34 @@ class MultiWriterIdGenerator:
return manager()
+ async def get_next_mult(self, n: int):
+ """
+ Usage:
+ with await stream_id_gen.get_next_mult(5) as stream_ids:
+ # ... persist events ...
+ """
+ next_ids = await self._db.runInteraction(
+ "_load_next_mult_id", self._load_next_mult_id_txn, n
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ assert max(self.get_positions().values(), default=0) < min(next_ids)
+
+ with self._lock:
+ self._unfinished_ids.update(next_ids)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_ids
+ finally:
+ for i in next_ids:
+ self._mark_id_as_finished(i)
+
+ return manager()
+
def get_next_txn(self, txn: LoggingTransaction):
"""
Usage:
@@ -344,16 +343,20 @@ class MultiWriterIdGenerator:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, next_id)
- def get_current_token(self, instance_name: str = None) -> int:
- """Gets the current position of a named writer (defaults to current
- instance).
+ self._add_persisted_position(next_id)
- Returns 0 if we don't have a position for the named writer (likely due
- to it being a new writer).
+ def get_current_token(self) -> int:
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
"""
- if instance_name is None:
- instance_name = self._instance_name
+ # Currently we don't support this operation, as it's not obvious how to
+ # condense the stream positions of multiple writers into a single int.
+ raise NotImplementedError()
+
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+ """
with self._lock:
return self._current_positions.get(instance_name, 0)
@@ -374,3 +377,53 @@ class MultiWriterIdGenerator:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
)
+
+ self._add_persisted_position(new_id)
+
+ def get_persisted_upto_position(self) -> int:
+ """Get the max position where all previous positions have been
+ persisted.
+
+ Note: In the worst case scenario this will be equal to the minimum
+ position across writers. This means that the returned position here can
+ lag if one writer doesn't write very often.
+ """
+
+ with self._lock:
+ return self._persisted_upto_position
+
+ def _add_persisted_position(self, new_id: int):
+ """Record that we have persisted a position.
+
+ This is used to keep `_persisted_upto_position` up to date.
+ """
+
+ # We require that the lock is locked by the caller
+ assert self._lock.locked()
+
+ heapq.heappush(self._known_persisted_positions, new_id)
+
+ # We move the current min position up if the minimum of the current
+ # positions across all instances is higher (since by definition all
+ # positions less than that have been persisted).
+ min_curr = min(self._current_positions.values())
+ self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+
+ # We now iterate through the seen positions, discarding those that are
+ # less than the current min position, and incrementing the min position
+ # if it's exactly one greater.
+ #
+ # This is also where we discard items from `_known_persisted_positions`
+ # (to ensure the list doesn't infinitely grow).
+ while self._known_persisted_positions:
+ if self._known_persisted_positions[0] <= self._persisted_upto_position:
+ heapq.heappop(self._known_persisted_positions)
+ elif (
+ self._known_persisted_positions[0] == self._persisted_upto_position + 1
+ ):
+ heapq.heappop(self._known_persisted_positions)
+ self._persisted_upto_position += 1
+ else:
+ # There was a gap in seen positions, so there is nothing more to
+ # do.
+ break
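Note: the heap bookkeeping in `_add_persisted_position` above can be read in isolation. This standalone sketch (not the class itself, and ignoring the min-across-instances step) shows how the persisted-upto position only advances across gapless runs:

    import heapq
    from typing import List

    def advance(persisted_upto: int, known: List[int], new_id: int) -> int:
        """Illustration of the gapless-run logic; `known` is a min-heap of
        positions persisted but not yet folded into `persisted_upto`."""
        heapq.heappush(known, new_id)
        while known:
            if known[0] <= persisted_upto:
                heapq.heappop(known)        # already covered, discard
            elif known[0] == persisted_upto + 1:
                heapq.heappop(known)        # contiguous, so advance
                persisted_upto += 1
            else:
                break                       # gap: cannot advance further yet
        return persisted_upto

    # Positions arriving out of order: 7 and 6 cannot advance past 4, but once
    # 5 arrives the run 5, 6, 7 is gapless and the position jumps to 7.
    pos, heap = 4, []
    for n in (7, 6, 5):
        pos = advance(pos, heap, n)
    print(pos)  # 7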
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 63dfea4220..ffc1894748 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import threading
-from typing import Callable, Optional
+from typing import Callable, List, Optional
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute("SELECT nextval(?)", (self._sequence_name,))
return txn.fetchone()[0]
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ txn.execute(
+ "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+ )
+ return [i for (i,) in txn]
+
GetFirstCallbackType = Callable[[Cursor], int]
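Note: the `generate_series` trick above fetches `n` sequence values in a single query; each row comes back as a one-element tuple, hence the `(i,)` unpacking. A hedged illustration of the result shape (values are made up, and gaps are possible since Postgres sequence advances are not rolled back with transactions):

    # Hypothetical cursor rows for n = 3.
    rows = [(101,), (102,), (103,)]
    ids = [i for (i,) in rows]
    print(ids)  # [101, 102, 103]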