diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 90a1f9e8b1..56818f4df8 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -16,9 +16,8 @@
import logging
from typing import Optional
-from canonicaljson import json
-
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import json_encoder
from . import engines
@@ -457,7 +456,7 @@ class BackgroundUpdater(object):
progress(dict): The progress of the update.
"""
- progress_json = json.dumps(progress)
+ progress_json = json_encoder.encode(progress)
self.db_pool.simple_update_one_txn(
txn,
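Note: every `json.dumps` → `json_encoder.encode` change in this PR swaps a per-call `json.dumps` for one shared encoder from `synapse.util`, so serialisation settings live in a single place. A minimal sketch of the idea; the exact encoder options are an assumption here, not necessarily what `synapse.util` configures:

```python
# Sketch: one module-level JSONEncoder reused by every call site. The
# allow_nan/separators options are illustrative assumptions.
import json

json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))

progress_json = json_encoder.encode({"last_room_id": "!abc:example.com"})
print(progress_json)  # compact output: {"last_room_id":"!abc:example.com"}
```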
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b9aef96b08..bc327e344e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,6 +28,7 @@ from typing import (
Optional,
Tuple,
TypeVar,
+ Union,
)
from prometheus_client import Histogram
@@ -125,7 +126,7 @@ class LoggingTransaction:
method.
Args:
- txn: The database transcation object to wrap.
+ txn: The database transaction object to wrap.
name: The name of this transaction, for logging.
database_engine
after_callbacks: A list that callbacks will be appended to
@@ -160,7 +161,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
- def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+ def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@@ -171,7 +172,9 @@ class LoggingTransaction:
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
- def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+ def call_on_exception(
+ self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
+ ):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
@@ -195,7 +198,7 @@ class LoggingTransaction:
def description(self) -> Any:
return self.txn.description
- def execute_batch(self, sql, args):
+ def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
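For context, the newly typed `execute_batch` has two paths: on PostgreSQL it hands the whole parameter list to psycopg2's `execute_batch` helper (fewer server round trips), while other engines fall back to the per-row loop shown below it. A standalone sketch of the PostgreSQL path; the DSN and table are placeholders:

```python
# Sketch of the psycopg2 fast path behind LoggingTransaction.execute_batch.
# The DSN and table name are placeholders, not anything Synapse defines.
import psycopg2
from psycopg2.extras import execute_batch

conn = psycopg2.connect("dbname=example")
sql = "INSERT INTO demo (a, b) VALUES (%s, %s)"
rows = [(1, "x"), (2, "y"), (3, "z")]

with conn.cursor() as cur:
    # Batches many parameter sets into few round trips.
    execute_batch(cur, sql, rows)
conn.commit()
```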
@@ -204,17 +207,17 @@ class LoggingTransaction:
for val in args:
self.execute(sql, val)
- def execute(self, sql: str, *args: Any):
+ def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args)
- def executemany(self, sql: str, *args: Any):
+ def executemany(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
- def _do_execute(self, func, sql, *args):
+ def _do_execute(self, func, sql: str, *args: Any) -> None:
sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -240,7 +243,7 @@ class LoggingTransaction:
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
- def close(self):
+ def close(self) -> None:
self.txn.close()
@@ -249,13 +252,13 @@ class PerformanceCounters(object):
self.current_counters = {}
self.previous_counters = {}
- def update(self, key, duration_secs):
+ def update(self, key: str, duration_secs: float) -> None:
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
- def interval(self, interval_duration_secs, limit=3):
+ def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
counters = []
for name, (count, cum_time) in self.current_counters.items():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
@@ -279,6 +282,9 @@ class PerformanceCounters(object):
return top_n_counters
+R = TypeVar("R")
+
+
class DatabasePool(object):
"""Wraps a single physical database and connection pool.
@@ -327,12 +333,12 @@ class DatabasePool(object):
self._check_safe_to_upsert,
)
- def is_running(self):
+ def is_running(self) -> bool:
"""Is the database pool currently running
"""
return self._db_pool.running
- async def _check_safe_to_upsert(self):
+ async def _check_safe_to_upsert(self) -> None:
"""
Is it safe to use native UPSERT?
@@ -363,7 +369,7 @@ class DatabasePool(object):
self._check_safe_to_upsert,
)
- def start_profiling(self):
+ def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time()
def loop():
@@ -387,8 +393,15 @@ class DatabasePool(object):
self._clock.looping_call(loop, 10000)
def new_transaction(
- self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
- ):
+ self,
+ conn: Connection,
+ desc: str,
+ after_callbacks: List[_CallbackListEntry],
+ exception_callbacks: List[_CallbackListEntry],
+ func: "Callable[..., R]",
+ *args: Any,
+ **kwargs: Any
+ ) -> R:
start = monotonic_time()
txn_id = self._TXN_ID
@@ -537,7 +550,9 @@ class DatabasePool(object):
return result
- async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
+ async def runWithConnection(
+ self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ ) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
@@ -576,11 +591,11 @@ class DatabasePool(object):
)
@staticmethod
- def cursor_to_dict(cursor):
+ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
"""Converts a SQL cursor into an list of dicts.
Args:
- cursor : The DBAPI cursor which has executed a query.
+ cursor: The DBAPI cursor which has executed a query.
Returns:
A list of dicts where the key is the column header.
"""
@@ -588,7 +603,7 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor]
return results
- def execute(self, desc, decoder, query, *args):
+ def execute(self, desc: str, decoder: Callable, query: str, *args: Any):
"""Runs a single query for a result set.
Args:
@@ -597,7 +612,7 @@ class DatabasePool(object):
query - The query string to execute
*args - Query args.
Returns:
- The result of decoder(results)
+ Deferred which resolves to the result of decoder(results)
"""
def interaction(txn):
@@ -612,20 +627,25 @@ class DatabasePool(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
- async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+ async def simple_insert(
+ self,
+ table: str,
+ values: Dict[str, Any],
+ or_ignore: bool = False,
+ desc: str = "simple_insert",
+ ) -> bool:
"""Executes an INSERT query on the named table.
Args:
- table : string giving the table name
- values : dict of new column names and values for them
- or_ignore : bool stating whether an exception should be raised
+ table: string giving the table name
+ values: dict of new column names and values for them
+ or_ignore: bool stating whether an exception should be raised
when a conflicting row already exists. If True, False will be
returned by the function instead
- desc : string giving a description of the transaction
+ desc: string giving a description of the transaction
Returns:
- bool: Whether the row was inserted or not. Only useful when
- `or_ignore` is True
+ Whether the row was inserted or not. Only useful when `or_ignore` is True
"""
try:
await self.runInteraction(desc, self.simple_insert_txn, table, values)
@@ -638,7 +658,9 @@ class DatabasePool(object):
return True
@staticmethod
- def simple_insert_txn(txn, table, values):
+ def simple_insert_txn(
+ txn: LoggingTransaction, table: str, values: Dict[str, Any]
+ ) -> None:
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -649,11 +671,15 @@ class DatabasePool(object):
txn.execute(sql, vals)
- def simple_insert_many(self, table, values, desc):
+ def simple_insert_many(
+ self, table: str, values: List[Dict[str, Any]], desc: str
+ ) -> defer.Deferred:
return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
@staticmethod
- def simple_insert_many_txn(txn, table, values):
+ def simple_insert_many_txn(
+ txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
+ ) -> None:
if not values:
return
@@ -683,13 +709,13 @@ class DatabasePool(object):
async def simple_upsert(
self,
- table,
- keyvalues,
- values,
- insertion_values={},
- desc="simple_upsert",
- lock=True,
- ):
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ desc: str = "simple_upsert",
+ lock: bool = True,
+ ) -> Optional[bool]:
"""
`lock` should generally be set to True (the default), but can be set
@@ -703,16 +729,14 @@ class DatabasePool(object):
this table.
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key columns and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key columns and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- None or bool: Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
+ Native upserts always return None. Emulated upserts return True if a
+ new entry was created, False if an existing one was updated.
"""
attempts = 0
while True:
@@ -739,29 +763,34 @@ class DatabasePool(object):
)
def simple_upsert_txn(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ lock: bool = True,
+ ) -> Optional[bool]:
"""
Pick the UPSERT method which works best on the platform. Either the
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
Args:
txn: The transaction to use.
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- None or bool: Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
+ Native upserts always return None. Emulated upserts return True if a
+ new entry was created, False if an existing one was updated.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
- return self.simple_upsert_txn_native_upsert(
+ self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
)
+ return None
else:
return self.simple_upsert_txn_emulated(
txn,
@@ -773,18 +802,23 @@ class DatabasePool(object):
)
def simple_upsert_txn_emulated(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ lock: bool = True,
+ ) -> bool:
"""
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- bool: Return True if a new entry was created, False if an existing
+ Returns True if a new entry was created, False if an existing
one was updated.
"""
# We need to lock the table :(, unless we're *really* careful
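As the docstring above says, the emulated path is "UPDATE, and INSERT only if nothing matched", run inside a single (optionally locked) transaction. A simplified sketch of that logic; it deliberately skips the empty-`values` handling and retries of the real method:

```python
# Simplified sketch of an emulated upsert: try UPDATE, INSERT on a miss.
# "txn" is any DB-API cursor; identifier quoting is elided for brevity.
def emulated_upsert(txn, table, keyvalues, values):
    sets = ", ".join("%s = ?" % k for k in values)
    where = " AND ".join("%s = ?" % k for k in keyvalues)
    txn.execute(
        "UPDATE %s SET %s WHERE %s" % (table, sets, where),
        list(values.values()) + list(keyvalues.values()),
    )
    if txn.rowcount > 0:
        return False  # an existing row was updated

    allvalues = {**keyvalues, **values}
    txn.execute(
        "INSERT INTO %s (%s) VALUES (%s)"
        % (table, ", ".join(allvalues), ", ".join("?" for _ in allvalues)),
        list(allvalues.values()),
    )
    return True  # a new row was created
```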
@@ -842,19 +876,21 @@ class DatabasePool(object):
return True
def simple_upsert_txn_native_upsert(
- self, txn, table, keyvalues, values, insertion_values={}
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ ) -> None:
"""
Use the native UPSERT functionality in recent PostgreSQL versions.
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- Returns:
- None
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
"""
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
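The native path, by contrast, renders to a single statement. Roughly the SQL that `simple_upsert_txn_native_upsert` builds; this is a sketch, not the method's exact string building, and the empty-`values` DO NOTHING case and quoting are elided:

```python
# Sketch of the single-statement native upsert (PostgreSQL 9.5+ / recent
# SQLite): INSERT ... ON CONFLICT ... DO UPDATE.
def native_upsert_sql(table, keyvalues, values, insertion_values):
    allvalues = {**keyvalues, **insertion_values, **values}
    sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO UPDATE SET %s" % (
        table,
        ", ".join(allvalues),
        ", ".join("?" for _ in allvalues),
        ", ".join(keyvalues),
        ", ".join("%s = EXCLUDED.%s" % (k, k) for k in values),
    )
    return sql, list(allvalues.values())
```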
@@ -985,18 +1021,22 @@ class DatabasePool(object):
return txn.execute_batch(sql, args)
def simple_select_one(
- self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
- ):
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: bool = False,
+ desc: str = "simple_select_one",
+ ) -> defer.Deferred:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcols : list of strings giving the names of the columns to return
-
- allow_none : If true, return None instead of failing if the SELECT
- statement returns no rows
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ retcols: list of strings giving the names of the columns to return
+ allow_none: If true, return None instead of failing if the SELECT
+ statement returns no rows
"""
return self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
@@ -1004,19 +1044,22 @@ class DatabasePool(object):
def simple_select_one_onecol(
self,
- table,
- keyvalues,
- retcol,
- allow_none=False,
- desc="simple_select_one_onecol",
- ):
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: bool = False,
+ desc: str = "simple_select_one_onecol",
+ ) -> defer.Deferred:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcol : string giving the name of the column to return
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ retcol: string giving the name of the column to return
+ allow_none: If true, return None instead of failing if the SELECT
+ statement returns no rows
+ desc: description of the transaction, for logging and metrics
"""
return self.runInteraction(
desc,
@@ -1029,8 +1072,13 @@ class DatabasePool(object):
@classmethod
def simple_select_one_onecol_txn(
- cls, txn, table, keyvalues, retcol, allow_none=False
- ):
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: bool = False,
+ ) -> Optional[Any]:
ret = cls.simple_select_onecol_txn(
txn, table=table, keyvalues=keyvalues, retcol=retcol
)
@@ -1044,7 +1092,12 @@ class DatabasePool(object):
raise StoreError(404, "No row found")
@staticmethod
- def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+ def simple_select_onecol_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ ) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
@@ -1056,15 +1109,19 @@ class DatabasePool(object):
return [r[0] for r in txn]
def simple_select_onecol(
- self, table, keyvalues, retcol, desc="simple_select_onecol"
- ):
+ self,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcol: str,
+ desc: str = "simple_select_onecol",
+ ) -> defer.Deferred:
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
Args:
- table (str): table name
- keyvalues (dict|None): column names and values to select the rows with
- retcol (str): column whos value we wish to retrieve.
+ table: table name
+ keyvalues: column names and values to select the rows with
+ retcol: column whose value we wish to retrieve.
Returns:
Deferred: resolves to a list
@@ -1073,16 +1130,22 @@ class DatabasePool(object):
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
)
- def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+ def simple_select_list(
+ self,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcols: Iterable[str],
+ desc: str = "simple_select_list",
+ ) -> defer.Deferred:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- table (str): the table name
- keyvalues (dict[str, Any] | None):
+ table: the table name
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
+ retcols: the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
@@ -1091,17 +1154,23 @@ class DatabasePool(object):
)
@classmethod
- def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+ def simple_select_list_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcols: Iterable[str],
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- txn : Transaction object
- table (str): the table name
- keyvalues (dict[str, T] | None):
+ txn: Transaction object
+ table: the table name
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
+ retcols: the names of the columns to return
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1118,25 +1187,25 @@ class DatabasePool(object):
async def simple_select_many_batch(
self,
- table,
- column,
- iterable,
- retcols,
- keyvalues={},
- desc="simple_select_many_batch",
- batch_size=100,
- ):
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ retcols: Iterable[str],
+ keyvalues: Dict[str, Any] = {},
+ desc: str = "simple_select_many_batch",
+ batch_size: int = 100,
+ ) -> List[Any]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Filters rows by if value of `column` is in `iterable`.
Args:
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+ iterable: list of values to match against `column`
+ keyvalues: dict of column names and values to select the rows with
+ retcols: list of strings giving the names of the columns to return
"""
results = [] # type: List[Dict[str, Any]]
@@ -1165,19 +1234,27 @@ class DatabasePool(object):
return results
@classmethod
- def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+ def simple_select_many_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Filters rows by if value of `column` is in `iterable`.
Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ txn: Transaction object
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+ iterable: list of values to match against `column`
+ keyvalues: dict of column names and values to select the rows with
+ retcols: list of strings giving the names of the columns to return
"""
if not iterable:
return []
@@ -1198,13 +1275,24 @@ class DatabasePool(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
- def simple_update(self, table, keyvalues, updatevalues, desc):
+ def simple_update(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ desc: str,
+ ) -> defer.Deferred:
return self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues
)
@staticmethod
- def simple_update_txn(txn, table, keyvalues, updatevalues):
+ def simple_update_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ ) -> int:
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else:
@@ -1221,31 +1309,32 @@ class DatabasePool(object):
return txn.rowcount
def simple_update_one(
- self, table, keyvalues, updatevalues, desc="simple_update_one"
- ):
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ desc: str = "simple_update_one",
+ ) -> defer.Deferred:
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- updatevalues : dict giving column names and values to update
- retcols : optional list of column names to return
-
- If present, retcols gives a list of column names on which to perform
- a SELECT statement *before* performing the UPDATE statement. The values
- of these will be returned in a dict.
-
- These are performed within the same transaction, allowing an atomic
- get-and-set. This can be used to implement compare-and-set by putting
- the update column in the 'keyvalues' dict as well.
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ updatevalues: dict giving column names and values to update
"""
return self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
)
@classmethod
- def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ def simple_update_one_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ ) -> None:
rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
if rowcount == 0:
@@ -1253,8 +1342,18 @@ class DatabasePool(object):
if rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
+ # Ideally we could use the overload decorator here to specify that the
+ # return type is only optional if allow_none is True, but this does not work
+ # when you call a static method from an instance.
+ # See https://github.com/python/mypy/issues/7781
@staticmethod
- def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+ def simple_select_one_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: bool = False,
+ ) -> Optional[Dict[str, Any]]:
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
@@ -1273,24 +1372,28 @@ class DatabasePool(object):
return dict(zip(retcols, row))
- def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+ def simple_delete_one(
+ self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
+ ) -> defer.Deferred:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
"""
return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
@staticmethod
- def simple_delete_one_txn(txn, table, keyvalues):
+ def simple_delete_one_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+ ) -> None:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
@@ -1303,11 +1406,13 @@ class DatabasePool(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
- def simple_delete(self, table, keyvalues, desc):
+ def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str):
return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
@staticmethod
- def simple_delete_txn(txn, table, keyvalues):
+ def simple_delete_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+ ) -> int:
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1316,26 +1421,39 @@ class DatabasePool(object):
txn.execute(sql, list(keyvalues.values()))
return txn.rowcount
- def simple_delete_many(self, table, column, iterable, keyvalues, desc):
+ def simple_delete_many(
+ self,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ desc: str,
+ ) -> defer.Deferred:
return self.runInteraction(
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
)
@staticmethod
- def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ def simple_delete_many_txn(
+ txn: LoggingTransaction,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ ) -> int:
"""Executes a DELETE query on the named table.
Filters rows by if value of `column` is in `iterable`.
Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
+ txn: Transaction object
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+ iterable: list of values to match against `column`
+ keyvalues: dict of column names and values to select the rows with
Returns:
- int: Number rows deleted
+ Number of rows deleted
"""
if not iterable:
return 0
@@ -1356,8 +1474,14 @@ class DatabasePool(object):
return txn.rowcount
def get_cache_dict(
- self, db_conn, table, entity_column, stream_column, max_value, limit=100000
- ):
+ self,
+ db_conn: Connection,
+ table: str,
+ entity_column: str,
+ stream_column: str,
+ max_value: int,
+ limit: int = 100000,
+ ) -> Tuple[Dict[Any, int], int]:
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
@@ -1390,34 +1514,34 @@ class DatabasePool(object):
def simple_select_list_paginate(
self,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=None,
- keyvalues=None,
- order_direction="ASC",
- desc="simple_select_list_paginate",
- ):
+ table: str,
+ orderby: str,
+ start: int,
+ limit: int,
+ retcols: Iterable[str],
+ filters: Optional[Dict[str, Any]] = None,
+ keyvalues: Optional[Dict[str, Any]] = None,
+ order_direction: str = "ASC",
+ desc: str = "simple_select_list_paginate",
+ ) -> defer.Deferred:
"""
Executes a SELECT query on the named table with the given start index
and row limit, which may return anywhere from zero rows up to the limit,
returning the result as a list of dicts.
Args:
- table (str): the table name
- filters (dict[str, T] | None):
+ table: the table name
+ orderby: Column to order the results by.
+ start: Index to begin the query at.
+ limit: Number of results to return.
+ retcols: the names of the columns to return
+ filters:
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
- keyvalues (dict[str, T] | None):
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+ order_direction: Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
@@ -1437,16 +1561,16 @@ class DatabasePool(object):
@classmethod
def simple_select_list_paginate_txn(
cls,
- txn,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=None,
- keyvalues=None,
- order_direction="ASC",
- ):
+ txn: LoggingTransaction,
+ table: str,
+ orderby: str,
+ start: int,
+ limit: int,
+ retcols: Iterable[str],
+ filters: Optional[Dict[str, Any]] = None,
+ keyvalues: Optional[Dict[str, Any]] = None,
+ order_direction: str = "ASC",
+ ) -> List[Dict[str, Any]]:
"""
Executes a SELECT query on the named table with the given start index
and row limit, which may return anywhere from zero rows up to the limit,
@@ -1457,21 +1581,22 @@ class DatabasePool(object):
using 'AND'.
Args:
- txn : Transaction object
- table (str): the table name
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- filters (dict[str, T] | None):
+ txn: Transaction object
+ table: the table name
+ orderby: Column to order the results by.
+ start: Index to begin the query at.
+ limit: Number of results to return.
+ retcols: the names of the columns to return
+ filters:
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
- keyvalues (dict[str, T] | None):
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+ order_direction: Whether the results should be ordered "ASC" or "DESC".
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ The result as a list of dictionaries.
"""
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -1497,16 +1622,23 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn)
- def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+ def simple_search_list(
+ self,
+ table: str,
+ term: Optional[str],
+ col: str,
+ retcols: Iterable[str],
+ desc="simple_search_list",
+ ):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
+ table: the table name
+ term: term to search for, matched against the given column
+ col: the column the search term should be matched against
+ retcols: the names of the columns to return
+
Returns:
defer.Deferred: resolves to list[dict[str, Any]] or None
"""
@@ -1516,19 +1648,26 @@ class DatabasePool(object):
)
@classmethod
- def simple_search_list_txn(cls, txn, table, term, col, retcols):
+ def simple_search_list_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ term: Optional[str],
+ col: str,
+ retcols: Iterable[str],
+ ) -> Union[List[Dict[str, Any]], int]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- txn : Transaction object
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
+ txn: Transaction object
+ table: the table name
+ term: term to search for, matched against the given column
+ col: the column the search term should be matched against
+ retcols: the names of the columns to return
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
+ 0 if no term is given, otherwise a list of dictionaries.
"""
if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
@@ -1541,7 +1680,7 @@ class DatabasePool(object):
def make_in_list_sql_clause(
- database_engine, column: str, iterable: Iterable
+ database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 17fa470919..0934ae276c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -498,7 +498,7 @@ class DataStore(
)
def get_users_paginate(
- self, start, limit, name=None, guests=True, deactivated=False
+ self, start, limit, user_id=None, name=None, guests=True, deactivated=False
):
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
@@ -507,7 +507,8 @@ class DataStore(
Args:
start (int): start number to begin the query from
limit (int): number of rows to retrieve
- name (string): filter for user names
+ user_id (string): search for user_id. Ignored if name is not None
+ name (string): search for local part of user_id or display name
guests (bool): whether to include guest users
deactivated (bool): whether to include deactivated users
Returns:
@@ -516,11 +517,14 @@ class DataStore(
def get_users_paginate_txn(txn):
filters = []
- args = []
+ args = [self.hs.config.server_name]
if name:
+ filters.append("(name LIKE ? OR displayname LIKE ?)")
+ args.extend(["@%" + name + "%:%", "%" + name + "%"])
+ elif user_id:
filters.append("name LIKE ?")
- args.append("%" + name + "%")
+ args.extend(["%" + user_id + "%"])
if not guests:
filters.append("is_guest = 0")
@@ -530,20 +534,23 @@ class DataStore(
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
- sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
- txn.execute(sql, args)
- count = txn.fetchone()[0]
-
- args = [self.hs.config.server_name] + args + [limit, start]
- sql = """
- SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+ sql_base = """
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
{}
- ORDER BY u.name LIMIT ? OFFSET ?
""".format(
where_clause
)
+ sql = "SELECT COUNT(*) as total_users " + sql_base
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = (
+ "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+ + sql_base
+ + " ORDER BY u.name LIMIT ? OFFSET ?"
+ )
+ args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
return users, count
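The point of factoring out `sql_base` above is that the COUNT query and the page query now share exactly the same JOIN and WHERE text, so the reported total cannot drift from the rows returned. Concretely, for a `name` search the two statements come out as:

```python
# The two statements built from the shared sql_base for a `name` search;
# parameter values are bound separately and are placeholders here.
where_clause = "WHERE (name LIKE ? OR displayname LIKE ?)"
sql_base = """
    FROM users as u
    LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
    {}
""".format(where_clause)

count_sql = "SELECT COUNT(*) as total_users " + sql_base
page_sql = (
    "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
    + sql_base
    + " ORDER BY u.name LIMIT ? OFFSET ?"
)
```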
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 82aac2bbf3..04042a2c98 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
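Every `with self._x_id_gen.get_next()` → `with await self._x_id_gen.get_next()` change in this PR follows the same shape: `get_next()` is now a coroutine that resolves to a context manager, so callers await it and then enter it. A minimal sketch of a generator with that interface; the real ones also persist positions and coordinate multiple writers:

```python
# Minimal sketch of an async get_next() that resolves to a context manager,
# matching the "with await gen.get_next() as stream_id:" call sites.
import asyncio
from contextlib import contextmanager

class SketchIdGenerator:
    def __init__(self) -> None:
        self._current = 0

    async def get_next(self):
        # The real generators may need to hit the database here, hence async.
        self._current += 1
        next_id = self._current

        @contextmanager
        def manager():
            yield next_id

        return manager()

async def main() -> None:
    gen = SketchIdGenerator()
    with await gen.get_next() as stream_id:
        print("allocated stream id", stream_id)

asyncio.run(main())
```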
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 02568a2391..77723f7d4d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -16,13 +16,12 @@
import logging
import re
-from canonicaljson import json
-
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util import json_encoder
logger = logging.getLogger(__name__)
@@ -204,7 +203,7 @@ class ApplicationServiceTransactionWorkerStore(
new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table
- event_ids = json.dumps([e.event_id for e in events])
+ event_ids = json_encoder.encode([e.event_id for e in events])
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1f6e995c4f..bb85637a95 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9a786e2929..03b45dbc4d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
The new stream ID.
"""
- with self._device_list_id_gen.get_next() as stream_id:
+ with await self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -1146,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+ with await self._device_list_id_gen.get_next_mult(
+ len(device_ids)
+ ) as stream_ids:
await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
@@ -1159,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with self._device_list_id_gen.get_next_mult(
+ with await self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..385868bdab 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
"""Set a user's cross-signing key.
Args:
@@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
key (dict): the key data
+ stream_id (int): the stream ID to store alongside the key
"""
# the 'key' dict will look something like:
# {
@@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
# and finally, store the key itself
- with self._cross_signing_id_gen.get_next() as stream_id:
- self.db_pool.simple_insert_txn(
- txn,
- "e2e_cross_signing_keys",
- values={
- "user_id": user_id,
- "keytype": key_type,
- "keydata": json_encoder.encode(key),
- "stream_id": stream_id,
- },
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ "e2e_cross_signing_keys",
+ values={
+ "user_id": user_id,
+ "keytype": key_type,
+ "keydata": json_encoder.encode(key),
+ "stream_id": stream_id,
+ },
+ )
self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
- def set_e2e_cross_signing_key(self, user_id, key_type, key):
+ async def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
Args:
@@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set
key (dict): the key data
"""
- return self.db_pool.runInteraction(
- "add_e2e_cross_signing_key",
- self._set_e2e_cross_signing_key_txn,
- user_id,
- key_type,
- key,
- )
+
+ with await self._cross_signing_id_gen.get_next() as stream_id:
+ return await self.db_pool.runInteraction(
+ "add_e2e_cross_signing_key",
+ self._set_e2e_cross_signing_key_txn,
+ user_id,
+ key_type,
+ key,
+ stream_id,
+ )
def store_e2e_cross_signing_signatures(self, user_id, signatures):
"""Stores cross-signing signatures.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4826be630c..e6a97b018c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,14 +15,16 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Set, Tuple
from synapse.api.errors import StoreError
+from synapse.events import EventBase
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.types import Collection
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter
@@ -30,12 +32,14 @@ logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
- async def get_auth_chain(self, event_ids, include_given=False):
+ async def get_auth_chain(
+ self, event_ids: Collection[str], include_given: bool = False
+ ) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
- event_ids (list): state events
- include_given (bool): include the given events in result
+ event_ids: state events
+ include_given: include the given events in result
Returns:
list of events
@@ -45,43 +49,34 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
return await self.get_events_as_list(event_ids)
- def get_auth_chain_ids(
- self,
- event_ids: List[str],
- include_given: bool = False,
- ignore_events: Optional[Set[str]] = None,
- ):
+ async def get_auth_chain_ids(
+ self, event_ids: Collection[str], include_given: bool = False,
+ ) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
event_ids: state events
include_given: include the given events in result
- ignore_events: Set of events to exclude from the returned auth
- chain. This is useful if the caller will just discard the
- given events anyway, and saves us from figuring out their auth
- chains if not required.
Returns:
list of event_ids
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_ids,
include_given,
- ignore_events,
)
- def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
- if ignore_events is None:
- ignore_events = set()
-
+ def _get_auth_chain_ids_txn(
+ self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
+ ) -> List[str]:
if include_given:
results = set(event_ids)
else:
results = set()
- base_sql = "SELECT auth_id FROM event_auth WHERE "
+ base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
front = set(event_ids)
while front:
@@ -93,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(base_sql + clause, args)
new_front.update(r[0] for r in txn)
- new_front -= ignore_events
new_front -= results
front = new_front
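`_get_auth_chain_ids_txn` is a breadth-first walk over `event_auth`: expand the current frontier's auth events (now deduplicated in SQL via `SELECT DISTINCT`), drop everything already seen, and loop until the frontier is empty. The same loop over an in-memory map, with the batching and SQL elided:

```python
# Pure-Python sketch of the auth-chain walk, using an in-memory
# {event_id: [auth_event_ids]} map in place of the event_auth table.
from typing import Dict, List, Set

def auth_chain_ids(
    auth_graph: Dict[str, List[str]], event_ids: List[str]
) -> Set[str]:
    results: Set[str] = set()
    front = set(event_ids)
    while front:
        new_front: Set[str] = set()
        for event_id in front:
            # stands in for: SELECT DISTINCT auth_id FROM event_auth WHERE ...
            new_front.update(auth_graph.get(event_id, ()))
        new_front -= results  # don't re-expand events we've already seen
        front = new_front
        results.update(front)
    return results

# e.g. auth_chain_ids({"$m": ["$c", "$pl"], "$c": ["$pl"], "$pl": []}, ["$m"])
# returns {"$c", "$pl"}
```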
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b90e6de2d5..6313b41eef 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -153,11 +153,11 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = await self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4a3333c0db..e1241a724b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -620,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row["room_version_id"]
if not room_version_id:
- # this should only happen for out-of-band membership events
- if not internal_metadata.get("out_of_band_membership"):
- logger.warning(
- "Room %s for event %s is unknown", d["room_id"], event_id
+ # this should only happen for out-of-band membership events which
+ # arrived before #6983 landed. For all other events, we should have
+ # an entry in the 'rooms' table.
+ #
+ # However, the 'out_of_band_membership' flag is unreliable for older
+ # invites, so just accept it for all membership events.
+ #
+ if d["type"] != EventTypes.Member:
+ raise Exception(
+ "Room %s for event %s is unknown" % (d["room_id"], event_id)
)
- continue
- # take a wild stab at the room version based on the event format
+ # so, assuming this is an out-of-band-invite that arrived before #6983
+ # landed, we know that the room version must be v5 or earlier (because
+ # v6 hadn't been invented at that point, so invites from such rooms
+ # would have been rejected.)
+ #
+ # The main reason we need to know the room version here (other than
+ # choosing the right python Event class) is in case the event later has
+ # to be redacted - and all the room versions up to v5 used the same
+ # redaction algorithm.
+ #
+ # So, the following approximations should be adequate.
+
if format_version == EventFormatVersions.V1:
+ # if it's event format v1 then it must be room v1 or v2
room_version = RoomVersions.V1
elif format_version == EventFormatVersions.V2:
+ # if it's event format v2 then it must be room v3
room_version = RoomVersions.V3
else:
+ # if it's event format v3 then it must be room v4 or v5
room_version = RoomVersions.V5
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 0e3b8739c6..a488e0924b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1182,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with self._group_updates_id_gen.get_next() as next_id:
+ with await self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4e3ec02d14..c9f655dfb7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
- stream_ordering_manager = self._presence_id_gen.get_next_mult(
+ stream_ordering_manager = await self._presence_id_gen.get_next_mult(
len(presence_states)
)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index a585e54812..2fb5b02d7d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with self._push_rules_stream_id_gen.get_next() as stream_id:
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after:
@@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with self._push_rules_stream_id_gen.get_next() as stream_id:
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
)
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
- with self._push_rules_stream_id_gen.get_next() as stream_id:
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with self._push_rules_stream_id_gen.get_next() as stream_id:
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 1126fd0751..c388468273 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering,
profile_tag="",
) -> None:
- with self._pushers_id_gen.get_next() as stream_id:
+ with await self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with self._pushers_id_gen.get_next() as stream_id:
+ with await self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 19ad1c056f..6821476ee0 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -520,8 +520,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear
)
- stream_id_manager = self._receipts_id_gen.get_next()
- with stream_id_manager as stream_id:
+ with await self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 5986d32b18..336b578e23 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -968,6 +968,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
super(RegistrationStore, self).__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
+ self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
if self._account_validity.enabled:
self._clock.call_later(
@@ -1381,15 +1382,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
if not row:
- raise ThreepidValidationError(400, "Unknown session_id")
+ if self._ignore_unknown_session_error:
+ # If we need to inhibit the error caused by an incorrect session ID,
+ # use None as placeholder values for the client secret and the
+ # validation timestamp.
+ # It shouldn't be an issue because they're both only checked after
+ # the token check, which should fail. And if it doesn't for some
+ # reason, the next check is on the client secret, which is NOT NULL,
+ # so we don't have to worry about the client secret matching by
+ # accident.
+ row = {"client_secret": None, "validated_at": None}
+ else:
+ raise ThreepidValidationError(400, "Unknown session_id")
+
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
- if retrieved_client_secret != client_secret:
- raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id"
- )
-
row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_token",
@@ -1405,6 +1413,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expires = row["expires"]
next_link = row["next_link"]
+ if retrieved_client_secret != client_secret:
+ raise ThreepidValidationError(
+ 400, "This client_secret does not match the provided session_id"
+ )
+
# If the session is already validated, no need to revalidate
if validated_at:
return next_link
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 0142a856d5..99a8a9fab0 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -21,10 +21,6 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from canonicaljson import json
-
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
@@ -32,6 +28,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchStore
from synapse.types import ThirdPartyInstanceID
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -342,23 +339,22 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
- @defer.inlineCallbacks
- def is_room_published(self, room_id):
+ async def is_room_published(self, room_id: str) -> bool:
"""Check whether a room has been published in the local public room
directory.
Args:
- room_id (str)
+ room_id
Returns:
- bool: Whether the room is currently published in the room directory
+ Whether the room is currently published in the room directory
"""
# Get room information
- room_info = yield self.get_room(room_id)
+ room_info = await self.get_room(room_id)
if not room_info:
- defer.returnValue(False)
+ return False
# Check the is_public value
- defer.returnValue(room_info.get("is_public", False))
+ return room_info.get("is_public", False)
async def get_rooms_paginate(
self,
@@ -572,7 +568,7 @@ class RoomWorkerStore(SQLBaseStore):
# maximum, in order not to filter out events we should filter out when sending to
# the client.
if not self.config.retention_enabled:
- defer.returnValue({"min_lifetime": None, "max_lifetime": None})
+ return {"min_lifetime": None, "max_lifetime": None}
def get_retention_policy_for_room_txn(txn):
txn.execute(
@@ -1155,7 +1151,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1222,7 +1218,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1302,7 +1298,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
@@ -1335,7 +1331,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"event_id": event_id,
"user_id": user_id,
"reason": reason,
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
},
desc="add_event_report",
)
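For context on the json.dumps → json_encoder.encode swap repeated through these files: a minimal sketch of a shared compact encoder of the kind synapse.util exposes (the exact encoder settings here are an assumption, not copied from the module). Reusing one instance avoids constructing a JSONEncoder per call and lets these modules drop the canonicaljson import.

    import json

    # One shared encoder instance; compact separators keep stored JSON small.
    # (Illustrative settings, not necessarily synapse.util's exact ones.)
    json_encoder = json.JSONEncoder(separators=(",", ":"))

    print(json_encoder.encode({"reason": "spam", "score": -100}))
    # {"reason":"spam","score":-100}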
diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
new file mode 100644
index 0000000000..4cc96a5341
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A table of the IP address and user-agent used to complete each step of a
+-- user-interactive authentication session.
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
+ session_id TEXT NOT NULL,
+ ip TEXT NOT NULL,
+ user_agent TEXT NOT NULL,
+ UNIQUE (session_id, ip, user_agent),
+ FOREIGN KEY (session_id)
+ REFERENCES ui_auth_sessions (session_id)
+);
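A runnable sqlite3 sketch of why UNIQUE (session_id, ip, user_agent) is enough to make the tracking idempotent (ui_auth_sessions is reduced to a single column here purely for the demo):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE ui_auth_sessions (session_id TEXT PRIMARY KEY);
        CREATE TABLE ui_auth_sessions_ips(
            session_id TEXT NOT NULL,
            ip TEXT NOT NULL,
            user_agent TEXT NOT NULL,
            UNIQUE (session_id, ip, user_agent),
            FOREIGN KEY (session_id) REFERENCES ui_auth_sessions (session_id)
        );
        """
    )
    conn.execute("INSERT INTO ui_auth_sessions VALUES ('sess1')")
    row = ("sess1", "10.0.0.1", "Mozilla/5.0")
    # Repeated inserts of the same triple collapse to a single row.
    conn.execute("INSERT OR IGNORE INTO ui_auth_sessions_ips VALUES (?, ?, ?)", row)
    conn.execute("INSERT OR IGNORE INTO ui_auth_sessions_ips VALUES (?, ?, ?)", row)
    print(conn.execute("SELECT count(*) FROM ui_auth_sessions_ips").fetchone()[0])  # 1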
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e4e0a0c433..0c34bbf21a 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,11 +17,10 @@
import logging
from typing import Dict, List, Tuple
-from canonicaljson import json
-
from synapse.storage._base import db_to_json
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.types import JsonDict
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (user_id, room_id))
tags = []
for tag, content in txn:
- tags.append(json.dumps(tag) + ":" + content)
+ tags.append(json_encoder.encode(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, (user_id, room_id, tag_json)))
@@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
self.db_pool.simple_upsert_txn(
@@ -211,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -233,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
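The first tags.py hunk builds the tags object by string concatenation rather than parsing: the stored content column is already a JSON string, so only the tag name needs encoding. A runnable sketch with literal rows:

    import json

    json_encoder = json.JSONEncoder(separators=(",", ":"))

    # (tag, content) rows as they come back from the tags table: content
    # is already serialised JSON, so it is spliced in verbatim.
    rows = [("m.favourite", '{"order":0.5}'), ("u.work", "{}")]

    tags = [json_encoder.encode(tag) + ":" + content for tag, content in rows]
    tag_json = "{" + ",".join(tags) + "}"
    print(json.loads(tag_json))  # {'m.favourite': {'order': 0.5}, 'u.work': {}}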
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 37276f73f8..9eef8e57c5 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,15 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import attr
-from canonicaljson import json
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
-from synapse.util import stringutils as stringutils
+from synapse.util import json_encoder, stringutils
@attr.s
@@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore):
StoreError if a unique session ID cannot be generated.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
# autogen a session ID and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -143,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore):
await self.db_pool.simple_upsert(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id, "stage_type": stage_type},
- values={"result": json.dumps(result)},
+ values={"result": json_encoder.encode(result)},
desc="mark_ui_auth_stage_complete",
)
except self.db_pool.engine.module.IntegrityError:
@@ -184,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore):
The dictionary from the client root level, not the 'auth' key.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
await self.db_pool.simple_update_one(
table="ui_auth_sessions",
@@ -214,14 +214,16 @@ class UIAuthWorkerStore(SQLBaseStore):
value,
)
- def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+ def _set_ui_auth_session_data_txn(
+ self, txn: LoggingTransaction, session_id: str, key: str, value: Any
+ ):
# Get the current value.
result = self.db_pool.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
- )
+ ) # type: Dict[str, Any] # type: ignore
# Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"])
@@ -231,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
- updatevalues={"serverdict": json.dumps(serverdict)},
+ updatevalues={"serverdict": json_encoder.encode(serverdict)},
)
async def get_ui_auth_session_data(
@@ -258,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore):
return serverdict.get(key, default)
+ async def add_user_agent_ip_to_ui_auth_session(
+ self, session_id: str, user_agent: str, ip: str,
+ ):
+ """Add the given user agent / IP to the tracking table
+ """
+ await self.db_pool.simple_upsert(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
+ values={},
+ desc="add_user_agent_ip_to_ui_auth_session",
+ )
+
+ async def get_user_agents_ips_to_ui_auth_session(
+ self, session_id: str,
+ ) -> List[Tuple[str, str]]:
+ """Get the given user agents / IPs used during the ui auth process
+
+ Returns:
+ List of user_agent/ip pairs
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ )
+ return [(row["user_agent"], row["ip"]) for row in rows]
+
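Passing values={} with every column in keyvalues makes the upsert above a pure "insert if missing", so replaying the same request adds nothing. A toy async sketch of the two new methods' contract (FakeStore is hypothetical, not the real UIAuthWorkerStore):

    import asyncio

    class FakeStore:
        def __init__(self):
            self._rows = set()  # stands in for ui_auth_sessions_ips

        async def add_user_agent_ip_to_ui_auth_session(self, session_id, user_agent, ip):
            self._rows.add((session_id, user_agent, ip))

        async def get_user_agents_ips_to_ui_auth_session(self, session_id):
            return [(ua, ip) for (sid, ua, ip) in self._rows if sid == session_id]

    async def main():
        store = FakeStore()
        for _ in range(3):  # duplicate requests collapse to one row
            await store.add_user_agent_ip_to_ui_auth_session("sess1", "Mozilla/5.0", "10.0.0.1")
        print(await store.get_user_agents_ips_to_ui_auth_session("sess1"))

    asyncio.run(main())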
class UIAuthStore(UIAuthWorkerStore):
def delete_old_ui_auth_sessions(self, expiration_time: int):
@@ -275,12 +305,23 @@ class UIAuthStore(UIAuthWorkerStore):
expiration_time,
)
- def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+ def _delete_old_ui_auth_sessions_txn(
+ self, txn: LoggingTransaction, expiration_time: int
+ ):
# Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time])
session_ids = [r[0] for r in txn.fetchall()]
+ # Delete the corresponding IP/user agents.
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions_ips",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 0bf772d4d1..5b07847773 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -14,9 +14,10 @@
# limitations under the License.
import contextlib
+import heapq
import threading
from collections import deque
-from typing import Dict, Set
+from typing import Dict, List, Set
from typing_extensions import Deque
@@ -80,7 +81,7 @@ class StreamIdGenerator(object):
upwards, -1 to grow downwards.
Usage:
- with stream_id_gen.get_next() as stream_id:
+ with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -95,10 +96,10 @@ class StreamIdGenerator(object):
)
self._unfinished_ids = deque() # type: Deque[int]
- def get_next(self):
+ async def get_next(self):
"""
Usage:
- with stream_id_gen.get_next() as stream_id:
+ with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -117,10 +118,10 @@ class StreamIdGenerator(object):
return manager()
- def get_next_mult(self, n):
+ async def get_next_mult(self, n):
"""
Usage:
- with stream_id_gen.get_next(n) as stream_ids:
+ with await stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
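These call-site changes all follow from get_next/get_next_mult becoming coroutines that return a context manager, hence the `with await ...` shape. A toy, runnable version of the pattern (ToyIdGenerator is illustrative only):

    import asyncio
    import contextlib
    import itertools

    class ToyIdGenerator:
        def __init__(self):
            self._counter = itertools.count(1)

        async def get_next(self):
            next_id = next(self._counter)  # the real generator hits the DB here

            @contextlib.contextmanager
            def manager():
                try:
                    yield next_id
                finally:
                    pass  # the real generator marks the ID as finished here

            return manager()

    async def main():
        gen = ToyIdGenerator()
        with await gen.get_next() as stream_id:
            print("persisting with stream_id", stream_id)

    asyncio.run(main())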
@@ -210,6 +211,23 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ # We track the max position where we know everything before has been
+ # persisted. This is done by a) looking at the min across all instances
+ # and b) noting that if we have seen a run of persisted positions
+ # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
+ #
+ # Note: There is no guarantee that the IDs generated by the sequence
+ # will be gapless; gaps can form when e.g. a transaction was rolled
+ # back. This means that sometimes we won't be able to skip forward the
+ # position even though everything has been persisted. However, since
+ # gaps should be relatively rare, it's still worth doing the bookkeeping
+ # that allows us to skip forwards when there are gapless runs of
+ # positions.
+ self._persisted_upto_position = (
+ min(self._current_positions.values()) if self._current_positions else 0
+ )
+ self._known_persisted_positions = [] # type: List[int]
+
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
def _load_current_ids(
@@ -234,9 +252,12 @@ class MultiWriterIdGenerator:
return current_positions
- def _load_next_id_txn(self, txn):
+ def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn)
+ def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+ return self._sequence_gen.get_next_mult_txn(txn, n)
+
async def get_next(self):
"""
Usage:
@@ -262,6 +283,34 @@ class MultiWriterIdGenerator:
return manager()
+ async def get_next_mult(self, n: int):
+ """
+ Usage:
+ with await stream_id_gen.get_next_mult(5) as stream_ids:
+ # ... persist events ...
+ """
+ next_ids = await self._db.runInteraction(
+ "_load_next_mult_id", self._load_next_mult_id_txn, n
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ assert max(self.get_positions().values(), default=0) < min(next_ids)
+
+ with self._lock:
+ self._unfinished_ids.update(next_ids)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_ids
+ finally:
+ for i in next_ids:
+ self._mark_id_as_finished(i)
+
+ return manager()
+
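A small check of the invariant asserted in get_next_mult above: every freshly reserved ID must exceed every position already handed out, otherwise the sequence and the stream table have diverged.

    # Sketch with literal values (not the real get_positions output).
    current_positions = {"writer1": 10, "writer2": 7}
    next_ids = [11, 12, 13]

    # Holds: 10 < 11, so the new batch is safely ahead of all known positions.
    assert max(current_positions.values(), default=0) < min(next_ids)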
def get_next_txn(self, txn: LoggingTransaction):
"""
Usage:
@@ -326,3 +375,53 @@ class MultiWriterIdGenerator:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
)
+
+ self._add_persisted_position(new_id)
+
+ def get_persisted_upto_position(self) -> int:
+ """Get the max position where all previous positions have been
+ persisted.
+
+ Note: In the worst-case scenario this will be equal to the minimum
+ position across writers. This means that the returned position here can
+ lag if one writer doesn't write very often.
+ """
+
+ with self._lock:
+ return self._persisted_upto_position
+
+ def _add_persisted_position(self, new_id: int):
+ """Record that we have persisted a position.
+
+ This is used to keep `_persisted_upto_position` up to date.
+ """
+
+ # We require that the lock is locked by caller
+ assert self._lock.locked()
+
+ heapq.heappush(self._known_persisted_positions, new_id)
+
+ # We move the current min position up if the minimum of the current
+ # positions of all instances is higher (since by definition all
+ # positions less than that have been persisted).
+ min_curr = min(self._current_positions.values())
+ self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+
+ # We now iterate through the seen positions, discarding those that are
+ # less than the current min position, and incrementing the min position
+ # if it's exactly one greater.
+ #
+ # This is also where we discard items from `_known_persisted_positions`
+ # (to ensure the list doesn't infinitely grow).
+ while self._known_persisted_positions:
+ if self._known_persisted_positions[0] <= self._persisted_upto_position:
+ heapq.heappop(self._known_persisted_positions)
+ elif (
+ self._known_persisted_positions[0] == self._persisted_upto_position + 1
+ ):
+ heapq.heappop(self._known_persisted_positions)
+ self._persisted_upto_position += 1
+ else:
+ # There was a gap in seen positions, so there is nothing more to
+ # do.
+ break
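A runnable trace of the heap logic above: positions can be persisted out of order, and the persisted-upto marker only advances across contiguous runs.

    import heapq

    persisted_upto = 4
    known = []  # min-heap of persisted positions seen so far

    for new_id in [6, 8, 5, 7]:
        heapq.heappush(known, new_id)
        while known:
            if known[0] <= persisted_upto:
                heapq.heappop(known)   # already accounted for
            elif known[0] == persisted_upto + 1:
                heapq.heappop(known)   # contiguous: advance the marker
                persisted_upto += 1
            else:
                break                  # gap: wait for the missing ID
        print(new_id, "->", persisted_upto)
    # prints: 6 -> 4, 8 -> 4, 5 -> 6, 7 -> 8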
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 63dfea4220..ffc1894748 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import threading
-from typing import Callable, Optional
+from typing import Callable, List, Optional
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute("SELECT nextval(?)", (self._sequence_name,))
return txn.fetchone()[0]
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ txn.execute(
+ "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+ )
+ return [i for (i,) in txn]
+
GetFirstCallbackType = Callable[[Cursor], int]
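get_next_mult_txn leans on a Postgres idiom: generate_series(1, n) yields n rows, and nextval() is re-evaluated per row, so one statement reserves a whole batch of sequence values. A sketch against a local Postgres (assumes psycopg2 and a reachable database; note psycopg2 uses %s placeholders where Synapse's wrapper uses ?):

    import psycopg2

    conn = psycopg2.connect("dbname=demo")
    with conn, conn.cursor() as txn:
        txn.execute("CREATE SEQUENCE IF NOT EXISTS demo_seq")
        txn.execute("SELECT nextval('demo_seq') FROM generate_series(1, %s)", (3,))
        print([i for (i,) in txn])  # e.g. [1, 2, 3]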