diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index b66347a79f..1b1e46a64c 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -28,6 +28,7 @@ from typing import (
Sequence,
Tuple,
Type,
+ cast,
)
import attr
@@ -492,14 +493,14 @@ class BackgroundUpdater:
True if we have finished running all the background updates, otherwise False
"""
- def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
+ def get_background_updates_txn(txn: Cursor) -> List[Tuple[str, Optional[str]]]:
txn.execute(
"""
SELECT update_name, depends_on FROM background_updates
ORDER BY ordering, update_name
"""
)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, Optional[str]]], txn.fetchall())
if not self._current_background_update:
all_pending_updates = await self.db_pool.runInteraction(
@@ -511,14 +512,13 @@ class BackgroundUpdater:
return True
# find the first update which isn't dependent on another one in the queue.
- pending = {update["update_name"] for update in all_pending_updates}
- for upd in all_pending_updates:
- depends_on = upd["depends_on"]
+ pending = {update_name for update_name, depends_on in all_pending_updates}
+ for update_name, depends_on in all_pending_updates:
if not depends_on or depends_on not in pending:
break
logger.info(
"Not starting on bg update %s until %s is done",
- upd["update_name"],
+ update_name,
depends_on,
)
else:
@@ -528,7 +528,7 @@ class BackgroundUpdater:
"another: dependency cycle?"
)
- self._current_background_update = upd["update_name"]
+ self._current_background_update = update_name
# We have a background update to run, otherwise we would have returned
# early.
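
A minimal standalone sketch of the dependency selection above: the scheduler picks the first pending update whose depends_on is either unset or no longer pending, and the for/else turns "everything is blocked" into an error. The update names here are hypothetical, not real background updates.

from typing import List, Optional, Tuple

def pick_next_update(pending_updates: List[Tuple[str, Optional[str]]]) -> str:
    """Return the first update that is not blocked by another pending update."""
    pending = {name for name, _ in pending_updates}
    for update_name, depends_on in pending_updates:
        if not depends_on or depends_on not in pending:
            return update_name
    # Mirrors the for/else above: every update depends on another pending one,
    # which can only happen if there is a dependency cycle.
    raise Exception("Unable to find a background update to run: dependency cycle?")

updates = [("populate_bar", "populate_foo"), ("populate_foo", None)]
# "populate_bar" depends on "populate_foo", so "populate_foo" is chosen first.
assert pick_next_update(updates) == "populate_foo"
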
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index f39ae2d635..1529c86cc5 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -542,13 +542,15 @@ class EventsPersistenceStorageController:
return await res.get_state(self._state_controller, StateFilter.all())
async def _persist_event_batch(
- self, _room_id: str, task: _PersistEventsTask
+ self, room_id: str, task: _PersistEventsTask
) -> Dict[str, str]:
"""Callback for the _event_persist_queue
Calculates the change to current state and forward extremities, and
persists the given events and with those updates.
+ Assumes that we are only persisting events for one room at a time.
+
Returns:
A dictionary of event ID to event ID we didn't persist as we already
had another event persisted with the same TXN ID.
@@ -594,140 +596,23 @@ class EventsPersistenceStorageController:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
- # NB: Assumes that we are only persisting events for one room
- # at a time.
-
- # map room_id->set[event_ids] giving the new forward
- # extremities in each room
- new_forward_extremities: Dict[str, Set[str]] = {}
-
- # map room_id->(to_delete, to_insert) where to_delete is a list
- # of type/state keys to remove from current state, and to_insert
- # is a map (type,key)->event_id giving the state delta in each
- # room
- state_delta_for_room: Dict[str, DeltaState] = {}
+ new_forward_extremities = None
+ state_delta_for_room = None
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
- # Work out the new "current state" for each room.
+ # Work out the new "current state" for the room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
- events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
- for event, context in chunk:
- events_by_room.setdefault(event.room_id, []).append(
- (event, context)
- )
-
- for room_id, ev_ctx_rm in events_by_room.items():
- latest_event_ids = (
- await self.main_store.get_latest_event_ids_in_room(room_id)
- )
- new_latest_event_ids = await self._calculate_new_extremities(
- room_id, ev_ctx_rm, latest_event_ids
- )
-
- if new_latest_event_ids == latest_event_ids:
- # No change in extremities, so no change in state
- continue
-
- # there should always be at least one forward extremity.
- # (except during the initial persistence of the send_join
- # results, in which case there will be no existing
- # extremities, so we'll `continue` above and skip this bit.)
- assert new_latest_event_ids, "No forward extremities left!"
-
- new_forward_extremities[room_id] = new_latest_event_ids
-
- len_1 = (
- len(latest_event_ids) == 1
- and len(new_latest_event_ids) == 1
- )
- if len_1:
- all_single_prev_not_state = all(
- len(event.prev_event_ids()) == 1
- and not event.is_state()
- for event, ctx in ev_ctx_rm
- )
- # Don't bother calculating state if they're just
- # a long chain of single ancestor non-state events.
- if all_single_prev_not_state:
- continue
-
- state_delta_counter.inc()
- if len(new_latest_event_ids) == 1:
- state_delta_single_event_counter.inc()
-
- # This is a fairly handwavey check to see if we could
- # have guessed what the delta would have been when
- # processing one of these events.
- # What we're interested in is if the latest extremities
- # were the same when we created the event as they are
- # now. When this server creates a new event (as opposed
- # to receiving it over federation) it will use the
- # forward extremities as the prev_events, so we can
- # guess this by looking at the prev_events and checking
- # if they match the current forward extremities.
- for ev, _ in ev_ctx_rm:
- prev_event_ids = set(ev.prev_event_ids())
- if latest_event_ids == prev_event_ids:
- state_delta_reuse_delta_counter.inc()
- break
-
- logger.debug("Calculating state delta for room %s", room_id)
- with Measure(
- self._clock, "persist_events.get_new_state_after_events"
- ):
- res = await self._get_new_state_after_events(
- room_id,
- ev_ctx_rm,
- latest_event_ids,
- new_latest_event_ids,
- )
- current_state, delta_ids, new_latest_event_ids = res
-
- # there should always be at least one forward extremity.
- # (except during the initial persistence of the send_join
- # results, in which case there will be no existing
- # extremities, so we'll `continue` above and skip this bit.)
- assert new_latest_event_ids, "No forward extremities left!"
-
- new_forward_extremities[room_id] = new_latest_event_ids
-
- # If either are not None then there has been a change,
- # and we need to work out the delta (or use that
- # given)
- delta = None
- if delta_ids is not None:
- # If there is a delta we know that we've
- # only added or replaced state, never
- # removed keys entirely.
- delta = DeltaState([], delta_ids)
- elif current_state is not None:
- with Measure(
- self._clock, "persist_events.calculate_state_delta"
- ):
- delta = await self._calculate_state_delta(
- room_id, current_state
- )
-
- if delta:
- # If we have a change of state then lets check
- # whether we're actually still a member of the room,
- # or if our last user left. If we're no longer in
- # the room then we delete the current state and
- # extremities.
- is_still_joined = await self._is_server_still_joined(
- room_id,
- ev_ctx_rm,
- delta,
- )
- if not is_still_joined:
- logger.info("Server no longer in room %s", room_id)
- delta.no_longer_in_room = True
-
- state_delta_for_room[room_id] = delta
+ (
+ new_forward_extremities,
+ state_delta_for_room,
+ ) = await self._calculate_new_forward_extremities_and_state_delta(
+ room_id, chunk
+ )
await self.persist_events_store._persist_events_and_state_updates(
+ room_id,
chunk,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
@@ -737,6 +622,117 @@ class EventsPersistenceStorageController:
return replaced_events
+ async def _calculate_new_forward_extremities_and_state_delta(
+ self, room_id: str, ev_ctx_rm: List[Tuple[EventBase, EventContext]]
+ ) -> Tuple[Optional[Set[str]], Optional[DeltaState]]:
+ """Calculates the new forward extremities and state delta for a room
+ given events to persist.
+
+ Assumes that we are only persisting events for one room at a time.
+
+ Returns:
+ A tuple of:
+            A set of str giving the new forward extremities of the room.
+
+ The state delta for the room.
+ """
+
+ latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id)
+ new_latest_event_ids = await self._calculate_new_extremities(
+ room_id, ev_ctx_rm, latest_event_ids
+ )
+
+ if new_latest_event_ids == latest_event_ids:
+ # No change in extremities, so no change in state
+ return (None, None)
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+        # extremities, so we'll return early above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremities = new_latest_event_ids
+
+ len_1 = len(latest_event_ids) == 1 and len(new_latest_event_ids) == 1
+ if len_1:
+ all_single_prev_not_state = all(
+ len(event.prev_event_ids()) == 1 and not event.is_state()
+ for event, ctx in ev_ctx_rm
+ )
+ # Don't bother calculating state if they're just
+ # a long chain of single ancestor non-state events.
+ if all_single_prev_not_state:
+ return (new_forward_extremities, None)
+
+ state_delta_counter.inc()
+ if len(new_latest_event_ids) == 1:
+ state_delta_single_event_counter.inc()
+
+ # This is a fairly handwavey check to see if we could
+ # have guessed what the delta would have been when
+ # processing one of these events.
+ # What we're interested in is if the latest extremities
+ # were the same when we created the event as they are
+ # now. When this server creates a new event (as opposed
+ # to receiving it over federation) it will use the
+ # forward extremities as the prev_events, so we can
+ # guess this by looking at the prev_events and checking
+ # if they match the current forward extremities.
+ for ev, _ in ev_ctx_rm:
+ prev_event_ids = set(ev.prev_event_ids())
+ if latest_event_ids == prev_event_ids:
+ state_delta_reuse_delta_counter.inc()
+ break
+
+ logger.debug("Calculating state delta for room %s", room_id)
+ with Measure(self._clock, "persist_events.get_new_state_after_events"):
+ res = await self._get_new_state_after_events(
+ room_id,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
+ )
+ current_state, delta_ids, new_latest_event_ids = res
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+        # extremities, so we'll return early above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremities = new_latest_event_ids
+
+ # If either are not None then there has been a change,
+ # and we need to work out the delta (or use that
+ # given)
+ delta = None
+ if delta_ids is not None:
+ # If there is a delta we know that we've
+ # only added or replaced state, never
+ # removed keys entirely.
+ delta = DeltaState([], delta_ids)
+ elif current_state is not None:
+ with Measure(self._clock, "persist_events.calculate_state_delta"):
+ delta = await self._calculate_state_delta(room_id, current_state)
+
+ if delta:
+            # If we have a change of state then let's check
+ # whether we're actually still a member of the room,
+ # or if our last user left. If we're no longer in
+ # the room then we delete the current state and
+ # extremities.
+ is_still_joined = await self._is_server_still_joined(
+ room_id,
+ ev_ctx_rm,
+ delta,
+ )
+ if not is_still_joined:
+ logger.info("Server no longer in room %s", room_id)
+ delta.no_longer_in_room = True
+
+ return (new_forward_extremities, delta)
+
async def _calculate_new_extremities(
self,
room_id: str,
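
For orientation, the three result shapes of the new helper and how _persist_event_batch treats them. This is a hedged sketch (object stands in for DeltaState), not additional Synapse behaviour.

from typing import Optional, Set, Tuple

def describe(result: Tuple[Optional[Set[str]], Optional[object]]) -> str:
    """Sketch of how the (extremities, delta) pair is interpreted by the caller."""
    new_forward_extremities, state_delta = result
    if new_forward_extremities is None and state_delta is None:
        return "extremities unchanged, so current state is unchanged"
    if state_delta is None:
        return "new extremities, but only a chain of non-state events: no delta"
    return "new extremities plus a state delta (possibly marking the room as left)"

print(describe((None, None)))
print(describe(({"$abc:example.org"}, None)))
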
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 46957723a1..9f7959c45d 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -16,7 +16,6 @@ from itertools import chain
from typing import (
TYPE_CHECKING,
AbstractSet,
- Any,
Callable,
Collection,
Dict,
@@ -32,6 +31,7 @@ from typing import (
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.logging.opentracing import tag_args, trace
+from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.storage.roommember import ProfileInfo
from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
@@ -531,19 +531,9 @@ class StateStorageController:
@tag_args
async def get_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
- ) -> Tuple[int, List[Dict[str, Any]]]:
+ ) -> Tuple[int, List[StateDelta]]:
"""Fetch a list of room state changes since the given stream id
- Each entry in the result contains the following fields:
- - stream_id (int)
- - room_id (str)
- - type (str): event type
- - state_key (str):
- - event_id (str|None): new event_id for this state key. None if the
- state has been deleted.
- - prev_event_id (str|None): previous event_id for this state key. None
- if it's new state.
-
Args:
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 695229bc91..05775425b7 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -19,7 +19,6 @@ import logging
import time
import types
from collections import defaultdict
-from sys import intern
from time import monotonic as monotonic_time
from typing import (
TYPE_CHECKING,
@@ -36,7 +35,6 @@ from typing import (
Tuple,
Type,
TypeVar,
- Union,
cast,
overload,
)
@@ -63,7 +61,8 @@ from synapse.storage.engines import (
BaseDatabaseEngine,
Psycopg2Engine,
PsycopgEngine,
- Sqlite3Engine, PostgresEngine,
+ Sqlite3Engine,
+ PostgresEngine,
)
from synapse.storage.types import Connection, Cursor, SQLQueryParameters
from synapse.util.async_helpers import delay_cancellation
@@ -416,7 +415,6 @@ class LoggingTransaction:
assert isinstance(self.database_engine, PostgresEngine)
if isinstance(self.database_engine, Psycopg2Engine):
-
from psycopg2.extras import execute_values
return self._do_execute(
@@ -470,6 +468,16 @@ class LoggingTransaction:
self._do_execute(self.txn.execute, sql, parameters)
def executemany(self, sql: str, *args: Any) -> None:
+ """Repeatedly execute the same piece of SQL with different parameters.
+
+ See https://peps.python.org/pep-0249/#executemany. Note in particular that
+
+ > Use of this method for an operation which produces one or more result sets
+ > constitutes undefined behavior
+
+ so you can't use this for e.g. a SELECT, an UPDATE ... RETURNING, or a
+ DELETE FROM... RETURNING.
+ """
# TODO: we should add a type for *args here. Looking at Cursor.executemany
# and DBAPI2 it ought to be Sequence[_Parameter], but we pass in
# Iterable[Iterable[Any]] in execute_batch and execute_values above, which mypy
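
A small self-contained illustration of the PEP 249 point in the new docstring, using sqlite3 rather than Synapse's pool: executemany is for statements run repeatedly for their side effects, while anything that produces a result set goes through execute() and fetchall(). The demo table is made up.

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE background_updates_demo (update_name TEXT, ordering INT)")

# Fine: the same INSERT run once per parameter row.
cur.executemany(
    "INSERT INTO background_updates_demo VALUES (?, ?)",
    [("populate_foo", 1), ("populate_bar", 2)],
)

# Not fine via executemany: a SELECT is undefined behaviour per the DB-API,
# so use execute() and fetch the rows instead.
cur.execute("SELECT update_name FROM background_updates_demo ORDER BY ordering")
print(cur.fetchall())  # [('populate_foo',), ('populate_bar',)]
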
@@ -661,13 +669,16 @@ class DatabasePool:
If the background updates have not completed, wait 15 sec and check again.
"""
- updates = await self.simple_select_list(
- "background_updates",
- keyvalues=None,
- retcols=["update_name"],
- desc="check_background_updates",
+ updates = cast(
+ List[Tuple[str]],
+ await self.simple_select_list(
+ "background_updates",
+ keyvalues=None,
+ retcols=["update_name"],
+ desc="check_background_updates",
+ ),
)
- background_update_names = [x["update_name"] for x in updates]
+ background_update_names = [x[0] for x in updates]
for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
if update_name not in background_update_names:
@@ -1085,57 +1096,20 @@ class DatabasePool:
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
)
- @staticmethod
- def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
- """Converts a SQL cursor into an list of dicts.
-
- Args:
- cursor: The DBAPI cursor which has executed a query.
- Returns:
- A list of dicts where the key is the column header.
- """
- assert cursor.description is not None, "cursor.description was None"
- col_headers = [intern(str(column[0])) for column in cursor.description]
- results = [dict(zip(col_headers, row)) for row in cursor]
- return results
-
- @overload
- async def execute(
- self, desc: str, decoder: Literal[None], query: str, *args: Any
- ) -> List[Tuple[Any, ...]]:
- ...
-
- @overload
- async def execute(
- self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
- ) -> R:
- ...
-
- async def execute(
- self,
- desc: str,
- decoder: Optional[Callable[[Cursor], R]],
- query: str,
- *args: Any,
- ) -> Union[List[Tuple[Any, ...]], R]:
+ async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
"""Runs a single query for a result set.
Args:
desc: description of the transaction, for logging and metrics
- decoder - The function which can resolve the cursor results to
- something meaningful.
query - The query string to execute
*args - Query args.
Returns:
-            The result of decoder(results)
+            A list of tuples containing the results of the query.
"""
- def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
+ def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
txn.execute(query, args)
- if decoder:
- return decoder(txn)
- else:
- return txn.fetchall()
+ return txn.fetchall()
return await self.runInteraction(desc, interaction)
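
With cursor_to_dict gone, callers unpack plain tuples; where a mapping is still genuinely wanted, the same zip-with-cursor.description trick can be done locally. A hedged sqlite3 sketch of both styles, using a demo table rather than a Synapse one:

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE users_demo (name TEXT, admin INT)")
cur.execute("INSERT INTO users_demo VALUES ('@alice:example.org', 1)")

cur.execute("SELECT name, admin FROM users_demo")
# Preferred style after this change: positional tuple unpacking.
for name, admin in cur.fetchall():
    print(name, bool(admin))

# Local equivalent of the removed cursor_to_dict, if a dict is really needed.
cur.execute("SELECT name, admin FROM users_demo")
columns = [col[0] for col in cur.description]
rows_as_dicts = [dict(zip(columns, row)) for row in cur.fetchall()]
print(rows_as_dicts)  # [{'name': '@alice:example.org', 'admin': 1}]
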
@@ -1211,6 +1185,9 @@ class DatabasePool:
keys: list of column names
values: for each row, a list of values in the same order as `keys`
"""
+ # If there's nothing to insert, then skip executing the query.
+ if not values:
+ return
if isinstance(txn.database_engine, Psycopg2Engine):
# We use `execute_values` as it can be a lot faster than `execute_batch`,
@@ -1490,12 +1467,12 @@ class DatabasePool:
allvalues.update(values)
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
- sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
+ sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %sDO %s" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
- f"WHERE {where_clause}" if where_clause else "",
+ f"WHERE {where_clause} " if where_clause else "",
latter,
)
txn.execute(sql, list(allvalues.values()))
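
The template change above is purely cosmetic: moving the space inside the conditional WHERE fragment means an upsert without a partial-index clause no longer renders with a double space before DO. A small sketch reproducing both renderings, with toy table and column names:

def render(where_clause: str) -> str:
    # Mirrors the format string used for the native upsert above.
    return "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %sDO %s" % (
        "room_tags",
        "user_id, room_id, tag",
        "?, ?, ?",
        "user_id, room_id, tag",
        f"WHERE {where_clause} " if where_clause else "",
        "UPDATE SET content=EXCLUDED.content",
    )

print(render(""))                 # ... ON CONFLICT (...) DO UPDATE ...
print(render("tag IS NOT NULL"))  # ... ON CONFLICT (...) WHERE tag IS NOT NULL DO UPDATE ...
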
@@ -1544,7 +1521,7 @@ class DatabasePool:
key_names: Collection[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[Any]],
+ value_values: Collection[Iterable[Any]],
) -> None:
"""
Upsert, many times.
@@ -1557,6 +1534,19 @@ class DatabasePool:
value_values: A list of each row's value column values.
Ignored if value_names is empty.
"""
+ # If there's nothing to upsert, then skip executing the query.
+ if not key_values:
+ return
+
+ # No value columns, therefore make a blank list so that the following
+ # zip() works correctly.
+ if not value_names:
+ value_values = [() for x in range(len(key_values))]
+ elif len(value_values) != len(key_values):
+ raise ValueError(
+ f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
+ )
+
if table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
txn, table, key_names, key_values, value_names, value_values
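
Hoisting the blank value_values fill-in into simple_upsert_many_txn means both the native and emulated paths see one (possibly empty) value tuple per key row, which is what the later zip() relies on. A tiny illustration with made-up rows:

key_values = [("@a:example.org", "DEV1"), ("@b:example.org", "DEV2")]
value_names: tuple = ()
value_values: list = []

# The hoisted guard: with no value columns, pair each key row with an empty
# tuple so that zip() still yields one entry per key row instead of nothing.
if not value_names:
    value_values = [() for _ in range(len(key_values))]

rows = [kv + vv for kv, vv in zip(key_values, value_values)]
print(rows)  # [('@a:example.org', 'DEV1'), ('@b:example.org', 'DEV2')]
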
@@ -1591,10 +1581,6 @@ class DatabasePool:
value_values: A list of each row's value column values.
Ignored if value_names is empty.
"""
- # No value columns, therefore make a blank list so that the following
- # zip() works correctly.
- if not value_names:
- value_values = [() for x in range(len(key_values))]
# Lock the table just once, to prevent it being done once per row.
# Note that, according to Postgres' documentation, once obtained,
@@ -1632,10 +1618,7 @@ class DatabasePool:
allnames.extend(value_names)
if not value_names:
- # No value columns, therefore make a blank list so that the
- # following zip() works correctly.
latter = "NOTHING"
- value_values = [() for x in range(len(key_values))]
else:
latter = "UPDATE SET " + ", ".join(
k + "=EXCLUDED." + k for k in value_names
@@ -1867,9 +1850,9 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]],
retcols: Collection[str],
desc: str = "simple_select_list",
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
+ more rows, returning the result as a list of tuples.
Args:
table: the table name
@@ -1880,8 +1863,7 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
Returns:
- A list of dictionaries, one per result row, each a mapping between the
- column names from `retcols` and that column's value for the row.
+            A list of tuples, one per result row, each containing the values of the columns in retcols, in order.
"""
return await self.runInteraction(
desc,
@@ -1899,9 +1881,9 @@ class DatabasePool:
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
+ more rows, returning the result as a list of tuples.
Args:
txn: Transaction object
@@ -1912,8 +1894,7 @@ class DatabasePool:
retcols: the names of the columns to return
Returns:
- A list of dictionaries, one per result row, each a mapping between the
- column names from `retcols` and that column's value for the row.
+            A list of tuples, one per result row, each containing the values of the columns in retcols, in order.
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1926,7 +1907,7 @@ class DatabasePool:
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
txn.execute(sql)
- return cls.cursor_to_dict(txn)
+ return txn.fetchall()
async def simple_select_many_batch(
self,
@@ -1937,9 +1918,9 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]] = None,
desc: str = "simple_select_many_batch",
batch_size: int = 100,
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
+ more rows.
Filters rows by whether the value of `column` is in `iterable`.
@@ -1951,10 +1932,13 @@ class DatabasePool:
keyvalues: dict of column names and values to select the rows with
desc: description of the transaction, for logging and metrics
batch_size: the number of rows for each select query
+
+ Returns:
+ The results as a list of tuples.
"""
keyvalues = keyvalues or {}
- results: List[Dict[str, Any]] = []
+ results: List[Tuple[Any, ...]] = []
for chunk in batch_iter(iterable, batch_size):
rows = await self.runInteraction(
@@ -1981,9 +1965,9 @@ class DatabasePool:
iterable: Collection[Any],
keyvalues: Dict[str, Any],
retcols: Iterable[str],
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
+ more rows.
Filters rows by whether the value of `column` is in `iterable`.
@@ -1994,7 +1978,11 @@ class DatabasePool:
iterable: list
keyvalues: dict of column names and values to select the rows with
retcols: list of strings giving the names of the columns to return
+
+ Returns:
+ The results as a list of tuples.
"""
+ # If there's nothing to select, then skip executing the query.
if not iterable:
return []
@@ -2012,7 +2000,7 @@ class DatabasePool:
)
txn.execute(sql, values)
- return cls.cursor_to_dict(txn)
+ return txn.fetchall()
async def simple_update(
self,
@@ -2129,13 +2117,13 @@ class DatabasePool:
raise ValueError(
f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
)
+ # If there is nothing to update, then skip executing the query.
+ if not key_values:
+ return
# List of tuples of (value values, then key values)
# (This matches the order needed for the query)
- args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)]
-
- for ks, vs in zip(key_values, value_values):
- args.append(tuple(vs) + tuple(ks))
+ args = [tuple(vv) + tuple(kv) for vv, kv in zip(value_values, key_values)]
# 'col1 = ?, col2 = ?, ...'
set_clause = ", ".join(f"{n} = ?" for n in value_names)
@@ -2147,9 +2135,7 @@ class DatabasePool:
where_clause = ""
# UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ?
- sql = f"""
- UPDATE {table} SET {set_clause} {where_clause}
- """
+ sql = f"UPDATE {table} SET {set_clause} {where_clause}"
txn.execute_batch(sql, args)
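
The rewritten comprehension drops the duplicated parameter rows the old code appended and keeps the ordering the statement expects: SET values first, then the WHERE key values. A short sketch of the batch parameters it produces, with made-up table and columns:

value_names = ("display_name",)
value_values = [("Alice",), ("Bob",)]
key_names = ("user_id", "device_id")
key_values = [("@a:example.org", "DEV1"), ("@b:example.org", "DEV2")]

set_clause = ", ".join(f"{n} = ?" for n in value_names)
where_clause = " AND ".join(f"{n} = ?" for n in key_names)
sql = f"UPDATE devices_demo SET {set_clause} WHERE {where_clause}"

# One parameter tuple per row: SET values first, then the WHERE key values.
args = [tuple(vv) + tuple(kv) for vv, kv in zip(value_values, key_values)]
print(sql)
print(args)  # [('Alice', '@a:example.org', 'DEV1'), ('Bob', '@b:example.org', 'DEV2')]
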
@@ -2365,11 +2351,10 @@ class DatabasePool:
Returns:
Number rows deleted
"""
+ # If there's nothing to delete, then skip executing the query.
if not values:
return 0
- sql = "DELETE FROM %s" % table
-
clause, values = make_in_list_sql_clause(txn.database_engine, column, values)
clauses = [clause]
@@ -2377,8 +2362,7 @@ class DatabasePool:
clauses.append("%s = ?" % (key,))
values.append(value)
- if clauses:
- sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+ sql = "DELETE FROM %s WHERE %s" % (table, " AND ".join(clauses))
txn.execute(sql, values)
return txn.rowcount
@@ -2481,7 +2465,7 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]] = None,
exclude_keyvalues: Optional[Dict[str, Any]] = None,
order_direction: str = "ASC",
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
@@ -2510,7 +2494,7 @@ class DatabasePool:
order_direction: Whether the results should be ordered "ASC" or "DESC".
Returns:
- The result as a list of dictionaries.
+ The result as a list of tuples.
"""
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -2537,69 +2521,7 @@ class DatabasePool:
)
txn.execute(sql, arg_list + [limit, start])
- return cls.cursor_to_dict(txn)
-
- async def simple_search_list(
- self,
- table: str,
- term: Optional[str],
- col: str,
- retcols: Collection[str],
- desc: str = "simple_search_list",
- ) -> Optional[List[Dict[str, Any]]]:
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- table: the table name
- term: term for searching the table matched to a column.
- col: column to query term should be matched to
- retcols: the names of the columns to return
-
- Returns:
- A list of dictionaries or None.
- """
-
- return await self.runInteraction(
- desc,
- self.simple_search_list_txn,
- table,
- term,
- col,
- retcols,
- db_autocommit=True,
- )
-
- @classmethod
- def simple_search_list_txn(
- cls,
- txn: LoggingTransaction,
- table: str,
- term: Optional[str],
- col: str,
- retcols: Iterable[str],
- ) -> Optional[List[Dict[str, Any]]]:
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- txn: Transaction object
- table: the table name
- term: term for searching the table matched to a column.
- col: column to query term should be matched to
- retcols: the names of the columns to return
-
- Returns:
- None if no term is given, otherwise a list of dictionaries.
- """
- if term:
- sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
- termvalues = ["%%" + term + "%%"]
- txn.execute(sql, termvalues)
- else:
- return None
-
- return cls.cursor_to_dict(txn)
+ return txn.fetchall()
def make_in_list_sql_clause(
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 101403578c..89f4077351 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,7 +15,9 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
+
+import attr
from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig
@@ -28,7 +30,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import get_domain_from_id
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
@@ -82,6 +84,25 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPaginateResponse:
+ """This is very similar to UserInfo, but not quite the same."""
+
+ name: str
+ user_type: Optional[str]
+ is_guest: bool
+ admin: bool
+ deactivated: bool
+ shadow_banned: bool
+ displayname: Optional[str]
+ avatar_url: Optional[str]
+ creation_ts: Optional[int]
+ approved: bool
+ erased: bool
+ last_seen_ts: int
+ locked: bool
+
+
class DataStore(
EventsBackgroundUpdatesStore,
ExperimentalFeaturesStore,
@@ -142,26 +163,6 @@ class DataStore(
super().__init__(database, db_conn, hs)
- async def get_users(self) -> List[JsonDict]:
- """Function to retrieve a list of users in users table.
-
- Returns:
- A list of dictionaries representing users.
- """
- return await self.db_pool.simple_select_list(
- table="users",
- keyvalues={},
- retcols=[
- "name",
- "password_hash",
- "is_guest",
- "admin",
- "user_type",
- "deactivated",
- ],
- desc="get_users",
- )
-
async def get_users_paginate(
self,
start: int,
@@ -176,7 +177,7 @@ class DataStore(
approved: bool = True,
not_user_types: Optional[List[str]] = None,
locked: bool = False,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[UserPaginateResponse], int]:
"""Function to retrieve a paginated list of users from
-        users list. This will return a json list of users and the
+        users list. This will return a list of users and the
total number of users matching the filter criteria.
@@ -202,7 +203,7 @@ class DataStore(
def get_users_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[UserPaginateResponse], int]:
filters = []
args: list = []
@@ -302,13 +303,24 @@ class DataStore(
"""
args += [limit, start]
txn.execute(sql, args)
- users = self.db_pool.cursor_to_dict(txn)
-
- # some of those boolean values are returned as integers when we're on SQLite
- columns_to_boolify = ["erased"]
- for user in users:
- for column in columns_to_boolify:
- user[column] = bool(user[column])
+ users = [
+ UserPaginateResponse(
+ name=row[0],
+ user_type=row[1],
+ is_guest=bool(row[2]),
+ admin=bool(row[3]),
+ deactivated=bool(row[4]),
+ shadow_banned=bool(row[5]),
+ displayname=row[6],
+ avatar_url=row[7],
+ creation_ts=row[8],
+ approved=bool(row[9]),
+ erased=bool(row[10]),
+ last_seen_ts=row[11],
+ locked=bool(row[12]),
+ )
+ for row in txn
+ ]
return users, count
@@ -316,7 +328,11 @@ class DataStore(
"get_users_paginate_txn", get_users_paginate_txn
)
- async def search_users(self, term: str) -> Optional[List[JsonDict]]:
+ async def search_users(
+ self, term: str
+ ) -> List[
+ Tuple[str, Optional[str], Union[int, bool], Union[int, bool], Optional[str]]
+ ]:
"""Function to search users list for one or more users with
the matched term.
@@ -324,15 +340,37 @@ class DataStore(
term: search term
Returns:
- A list of dictionaries or None.
+            A list of (name, password_hash, is_guest, admin, user_type) tuples; some fields may be None.
"""
- return await self.db_pool.simple_search_list(
- table="users",
- term=term,
- col="name",
- retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
- desc="search_users",
- )
+
+ def search_users(
+ txn: LoggingTransaction,
+ ) -> List[
+ Tuple[str, Optional[str], Union[int, bool], Union[int, bool], Optional[str]]
+ ]:
+ search_term = "%%" + term + "%%"
+
+ sql = """
+ SELECT name, password_hash, is_guest, admin, user_type
+ FROM users
+ WHERE name LIKE ?
+ """
+ txn.execute(sql, (search_term,))
+
+ return cast(
+ List[
+ Tuple[
+ str,
+ Optional[str],
+ Union[int, bool],
+ Union[int, bool],
+ Optional[str],
+ ]
+ ],
+ txn.fetchall(),
+ )
+
+ return await self.db_pool.runInteraction("search_users", search_users)
def check_database_before_upgrade(
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 80f146dd53..d7482a1f4e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -94,7 +94,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
hs.get_replication_notifier(),
"room_account_data",
"stream_id",
- extra_tables=[("room_tags_revisions", "stream_id")],
+ extra_tables=[
+ ("account_data", "stream_id"),
+ ("room_tags_revisions", "stream_id"),
+ ],
is_writer=self._instance_name in hs.config.worker.writers.account_data,
)
@@ -103,6 +106,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"AccountDataAndTagsChangeCache", account_max
)
+ self.db_pool.updates.register_background_index_update(
+ update_name="room_account_data_index_room_id",
+ index_name="room_account_data_room_id",
+ table="room_account_data",
+ columns=("room_id",),
+ )
+
self.db_pool.updates.register_background_update_handler(
"delete_account_data_for_deactivated_users",
self._delete_account_data_for_deactivated_users,
@@ -151,10 +161,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
- rows = self.db_pool.cursor_to_dict(txn)
return {
- row["account_data_type"]: db_to_json(row["content"]) for row in rows
+ account_data_type: db_to_json(content)
+ for account_data_type, content in txn
}
return await self.db_pool.runInteraction(
@@ -196,13 +206,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
- rows = self.db_pool.cursor_to_dict(txn)
by_room: Dict[str, Dict[str, JsonDict]] = {}
- for row in rows:
- room_data = by_room.setdefault(row["room_id"], {})
+ for room_id, account_data_type, content in txn:
+ room_data = by_room.setdefault(room_id, {})
- room_data[row["account_data_type"]] = db_to_json(row["content"])
+ room_data[account_data_type] = db_to_json(content)
return by_room
@@ -277,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_account_data_for_room_txn(
txn: LoggingTransaction,
- ) -> Dict[str, JsonDict]:
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "room_account_data",
- {"user_id": user_id, "room_id": room_id},
- ["account_data_type", "content"],
+ ) -> Dict[str, JsonMapping]:
+ rows = cast(
+ List[Tuple[str, str]],
+ self.db_pool.simple_select_list_txn(
+ txn,
+ table="room_account_data",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ retcols=["account_data_type", "content"],
+ ),
)
return {
- row["account_data_type"]: db_to_json(row["content"]) for row in rows
+ account_data_type: db_to_json(content)
+ for account_data_type, content in rows
}
return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 0553a0621a..fa7d1c469a 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,17 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- List,
- Optional,
- Pattern,
- Sequence,
- Tuple,
- cast,
-)
+from typing import TYPE_CHECKING, List, Optional, Pattern, Sequence, Tuple, cast
from synapse.appservice import (
ApplicationService,
@@ -207,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A list of ApplicationServices, which may be empty.
"""
- results = await self.db_pool.simple_select_list(
- "application_services_state", {"state": state.value}, ["as_id"]
+ results = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="application_services_state",
+ keyvalues={"state": state.value},
+ retcols=("as_id",),
+ ),
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
services = []
- for res in results:
+ for (as_id,) in results:
for service in as_list:
- if service.id == res["as_id"]:
+ if service.id == as_id:
services.append(service)
return services
@@ -353,21 +348,15 @@ class ApplicationServiceTransactionWorkerStore(
def _get_oldest_unsent_txn(
txn: LoggingTransaction,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[Tuple[int, str]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
- "SELECT * FROM application_services_txns WHERE as_id=?"
+ "SELECT txn_id, event_ids FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
- if not rows:
- return None
-
- entry = rows[0]
-
- return entry
+ return cast(Optional[Tuple[int, str]], txn.fetchone())
entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
@@ -376,8 +365,9 @@ class ApplicationServiceTransactionWorkerStore(
if not entry:
return None
- event_ids = db_to_json(entry["event_ids"])
+ txn_id, event_ids_str = entry
+ event_ids = db_to_json(event_ids_str)
events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts, device list summaries and unused
@@ -385,7 +375,7 @@ class ApplicationServiceTransactionWorkerStore(
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
- id=entry["txn_id"],
+ id=txn_id,
events=events,
ephemeral=[],
to_device_messages=[],
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2fbd389c71..4d0470ffd9 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import (
EventsStream,
+ EventsStreamAllStateRow,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
@@ -264,6 +265,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
(data.state_key,)
)
self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined]
+ elif row.type == EventsStreamAllStateRow.TypeId:
+ assert isinstance(data, EventsStreamAllStateRow)
+ # Similar to the above, but the entire caches are invalidated. This is
+ # unfortunate for the membership caches, but should recover quickly.
+ self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined]
+ self.get_rooms_for_user_with_stream_ordering.invalidate_all() # type: ignore[attr-defined]
+ self.get_rooms_for_user.invalidate_all() # type: ignore[attr-defined]
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 58177ecec1..711fdddd4e 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
"""
rows = await self.db_pool.execute(
- "_censor_redactions_fetch", None, sql, before_ts, 100
+ "_censor_redactions_fetch", sql, before_ts, 100
)
updates = []
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 7da47c3dd7..c006129625 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
+import attr
from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -42,7 +43,8 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
-class DeviceLastConnectionInfo(TypedDict):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DeviceLastConnectionInfo:
"""Metadata for the last connection seen for a user and device combination"""
# These types must match the columns in the `devices` table
@@ -499,8 +501,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
device_id: If None fetches all devices for the user
Returns:
- A dictionary mapping a tuple of (user_id, device_id) to dicts, with
- keys giving the column names from the devices table.
+ A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
"""
keyvalues = {"user_id": user_id}
@@ -508,7 +509,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
keyvalues["device_id"] = device_id
res = cast(
- List[DeviceLastConnectionInfo],
+ List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
@@ -516,7 +517,16 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
),
)
- return {(d["user_id"], d["device_id"]): d for d in res}
+ return {
+ (user_id, device_id): DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip=ip,
+ user_agent=user_agent,
+ last_seen=last_seen,
+ )
+ for user_id, ip, user_agent, device_id, last_seen in res
+ }
async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0
@@ -683,8 +693,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
device_id: If None fetches all devices for the user
Returns:
- A dictionary mapping a tuple of (user_id, device_id) to dicts, with
- keys giving the column names from the devices table.
+ A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
"""
ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
@@ -705,13 +714,13 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
continue
if not device_id or did == device_id:
- ret[(user_id, did)] = {
- "user_id": user_id,
- "ip": ip,
- "user_agent": user_agent,
- "device_id": did,
- "last_seen": last_seen,
- }
+ ret[(user_id, did)] = DeviceLastConnectionInfo(
+ user_id=user_id,
+ ip=ip,
+ user_agent=user_agent,
+ device_id=did,
+ last_seen=last_seen,
+ )
return ret
async def get_user_ip_and_agents(
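
The TypedDict here becomes a frozen attrs class, trading dict-style lookups for attribute access. A minimal sketch of the same shape, assuming the attrs package and using a demo class name rather than the real one:

from typing import Optional

import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLastConnectionInfoDemo:
    """Demo twin of the frozen attrs class above (not the real class)."""

    user_id: str
    device_id: str
    ip: Optional[str]
    user_agent: Optional[str]
    last_seen: Optional[int]

info = DeviceLastConnectionInfoDemo(
    user_id="@a:example.org",
    device_id="DEV1",
    ip="10.0.0.1",
    user_agent="Synapse",
    last_seen=1_700_000_000_000,
)
print(info.ip)        # attribute access replaces the old info["ip"] lookups
# info.ip = "other"   # would raise: the instance is frozen
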
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 744e98c6d0..3e7425d4a6 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -344,18 +344,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)`
if not device_ids_to_query:
- user_device_dicts = self.db_pool.simple_select_many_txn(
- txn,
- table="devices",
- column="user_id",
- iterable=user_ids_to_query,
- keyvalues={"hidden": False},
- retcols=("device_id",),
+ user_device_dicts = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ column="user_id",
+ iterable=user_ids_to_query,
+ keyvalues={"hidden": False},
+ retcols=("device_id",),
+ ),
)
- device_ids_to_query.update(
- {row["device_id"] for row in user_device_dicts}
- )
+ device_ids_to_query.update({row[0] for row in user_device_dicts})
if not device_ids_to_query:
# We've ended up with no devices to query.
@@ -449,7 +450,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_id: str,
device_id: Optional[str],
up_to_stream_id: int,
- limit: int,
+ limit: Optional[int] = None,
) -> int:
"""
Args:
@@ -477,17 +478,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": "No changes in cache since last check"})
return 0
- ROW_ID_NAME = self.database_engine.row_id_name
-
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
+ limit_statement = "" if limit is None else f"LIMIT {limit}"
sql = f"""
- DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
- SELECT {ROW_ID_NAME} FROM device_inbox
- WHERE user_id = ? AND device_id = ? AND stream_id <= ?
- LIMIT {limit}
+ DELETE FROM device_inbox WHERE user_id = ? AND device_id = ? AND stream_id <= (
+ SELECT MAX(stream_id) FROM (
+ SELECT stream_id FROM device_inbox
+ WHERE user_id = ? AND device_id = ? AND stream_id <= ?
+ ORDER BY stream_id
+ {limit_statement}
+ ) AS q1
)
"""
- txn.execute(sql, (user_id, device_id, up_to_stream_id))
+ txn.execute(sql, (user_id, device_id, user_id, device_id, up_to_stream_id))
return txn.rowcount
count = await self.db_pool.runInteraction(
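
The rewritten query emulates a LIMIT on DELETE (which isn't portable across the supported databases) by bounding the delete at the MAX(stream_id) of an ordered, limited subselect. A self-contained sqlite3 sketch of the same pattern with toy data:

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE inbox_demo (user_id TEXT, stream_id INT)")
cur.executemany(
    "INSERT INTO inbox_demo VALUES (?, ?)",
    [("@a:example.org", i) for i in range(1, 11)],
)

# Delete at most 3 of the oldest rows up to stream_id 8, without DELETE ... LIMIT.
cur.execute(
    """
    DELETE FROM inbox_demo WHERE user_id = ? AND stream_id <= (
        SELECT MAX(stream_id) FROM (
            SELECT stream_id FROM inbox_demo
            WHERE user_id = ? AND stream_id <= ?
            ORDER BY stream_id
            LIMIT 3
        ) AS q1
    )
    """,
    ("@a:example.org", "@a:example.org", 8),
)
print(cur.rowcount)  # 3
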
@@ -845,20 +848,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# We exclude hidden devices (such as cross-signing keys) here as they are
# not expected to receive to-device messages.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="devices",
- keyvalues={"user_id": user_id, "hidden": False},
- column="device_id",
- iterable=devices,
- retcols=("device_id",),
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ keyvalues={"user_id": user_id, "hidden": False},
+ column="device_id",
+ iterable=devices,
+ retcols=("device_id",),
+ ),
)
- for row in rows:
+ for (device_id,) in rows:
# Only insert into the local inbox if the device exists on
# this server
- device_id = row["device_id"]
-
with start_active_span("serialise_to_device_message"):
msg = messages_by_device[device_id]
set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])
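
Several hunks in this file switch single-column selects from dicts to 1-tuples; both access styles used in the PR are shown in this short standalone snippet (made-up device IDs):

rows = [("DEV1",), ("DEV2",)]

device_ids = {row[0] for row in rows}  # indexing style
for (device_id,) in rows:              # unpacking style, note the trailing comma
    assert device_id in device_ids
print(device_ids)
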
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index df596f35f9..04d12a876c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
allow_none=True,
)
- async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
+ async def get_devices_by_user(
+ self, user_id: str
+ ) -> Dict[str, Dict[str, Optional[str]]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.
@@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id:
Returns:
A mapping from device_id to a dict containing "device_id", "user_id"
- and "display_name" for each device.
+ and "display_name" for each device. Display name may be null.
"""
- devices = await self.db_pool.simple_select_list(
- table="devices",
- keyvalues={"user_id": user_id, "hidden": False},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_devices_by_user",
+ devices = cast(
+ List[Tuple[str, str, Optional[str]]],
+ await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id, "hidden": False},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user",
+ ),
)
- return {d["device_id"]: d for d in devices}
+ return {
+ d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
+ for d in devices
+ }
async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, str]]:
"""Retrieve the list of devices associated with a SSO IdP session ID.
Args:
@@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
Returns:
-            A list of dicts containing the device_id and the user_id of each device
+            A list of (user_id, device_id) tuples, one per matching device
"""
- return await self.db_pool.simple_select_list(
- table="device_auth_providers",
- keyvalues={
- "auth_provider_id": auth_provider_id,
- "auth_provider_session_id": auth_provider_session_id,
- },
- retcols=("user_id", "device_id"),
- desc="get_devices_by_auth_provider_session_id",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_auth_providers",
+ keyvalues={
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ retcols=("user_id", "device_id"),
+ desc="get_devices_by_auth_provider_session_id",
+ ),
)
@trace
@@ -692,7 +703,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
key_names=("destination", "user_id"),
key_values=[(destination, user_id) for user_id, _ in rows],
value_names=("stream_id",),
- value_values=((stream_id,) for _, stream_id in rows),
+ value_values=[(stream_id,) for _, stream_id in rows],
)
# Delete all sent outbound pokes
@@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_cached_devices_for_user(
self, user_id: str
) -> Mapping[str, JsonMapping]:
- devices = await self.db_pool.simple_select_list(
- table="device_lists_remote_cache",
- keyvalues={"user_id": user_id},
- retcols=("device_id", "content"),
- desc="get_cached_devices_for_user",
+ devices = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={"user_id": user_id},
+ retcols=("device_id", "content"),
+ desc="get_cached_devices_for_user",
+ ),
)
- return {
- device["device_id"]: db_to_json(device["content"]) for device in devices
- }
+ return {device[0]: db_to_json(device[1]) for device in devices}
def get_cached_device_list_changes(
self,
@@ -882,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
rows = await self.db_pool.execute(
"get_all_devices_changed",
- None,
sql,
from_key,
to_key,
@@ -966,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
WHERE from_user_id = ? AND stream_id > ?
"""
rows = await self.db_pool.execute(
- "get_users_whose_signatures_changed", None, sql, user_id, from_key
+ "get_users_whose_signatures_changed", sql, user_id, from_key
)
return {user for row in rows for user in db_to_json(row[0])}
else:
@@ -1052,16 +1063,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str]
) -> Mapping[str, Optional[str]]:
- rows = await self.db_pool.simple_select_many_batch(
- table="device_lists_remote_extremeties",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id", "stream_id"),
- desc="get_device_list_last_stream_id_for_remotes",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="device_lists_remote_extremeties",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id", "stream_id"),
+ desc="get_device_list_last_stream_id_for_remotes",
+ ),
)
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
- results.update({row["user_id"]: row["stream_id"] for row in rows})
+ results.update(rows)
return results
@@ -1077,22 +1091,28 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
- rows = await self.db_pool.simple_select_many_batch(
- table="device_lists_remote_resync",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id",),
- desc="get_user_ids_requiring_device_list_resync_with_iterable",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="device_lists_remote_resync",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync_with_iterable",
+ ),
)
else:
- rows = await self.db_pool.simple_select_list(
- table="device_lists_remote_resync",
- keyvalues=None,
- retcols=("user_id",),
- desc="get_user_ids_requiring_device_list_resync",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="device_lists_remote_resync",
+ keyvalues=None,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync",
+ ),
)
- return {row["user_id"] for row in rows}
+ return {row[0] for row in rows}
async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
@@ -1413,13 +1433,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_devices_not_accessed_since_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, str]]:
+ ) -> List[Tuple[str, str]]:
sql = """
SELECT user_id, device_id
FROM devices WHERE last_seen < ? AND hidden = FALSE
"""
txn.execute(sql, (since_ms,))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, str]], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_devices_not_accessed_since",
@@ -1427,11 +1447,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
devices: Dict[str, List[str]] = {}
- for row in rows:
+ for user_id, device_id in rows:
# Remote devices are never stale from our point of view.
- if self.hs.is_mine_id(row["user_id"]):
- user_devices = devices.setdefault(row["user_id"], [])
- user_devices.append(row["device_id"])
+ if self.hs.is_mine_id(user_id):
+ user_devices = devices.setdefault(user_id, [])
+ user_devices.append(device_id)
return devices
@@ -1600,7 +1620,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
#
# For each duplicate, we delete all the existing rows and put one back.
- KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
last_row = progress.get(
"last_row",
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
@@ -1608,44 +1627,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
def _txn(txn: LoggingTransaction) -> int:
clause, args = make_tuple_comparison_clause(
- [(x, last_row[x]) for x in KEY_COLS]
+ [
+ ("stream_id", last_row["stream_id"]),
+ ("destination", last_row["destination"]),
+ ("user_id", last_row["user_id"]),
+ ("device_id", last_row["device_id"]),
+ ]
)
- sql = """
+ sql = f"""
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
FROM device_lists_outbound_pokes
- WHERE %s
- GROUP BY %s
+ WHERE {clause}
+ GROUP BY stream_id, destination, user_id, device_id
HAVING count(*) > 1
- ORDER BY %s
+ ORDER BY stream_id, destination, user_id, device_id
LIMIT ?
- """ % (
- clause, # WHERE
- ",".join(KEY_COLS), # GROUP BY
- ",".join(KEY_COLS), # ORDER BY
- )
+ """
txn.execute(sql, args + [batch_size])
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
- row = None
- for row in rows:
+ stream_id, destination, user_id, device_id = None, None, None, None
+ for stream_id, destination, user_id, device_id, _ in rows:
self.db_pool.simple_delete_txn(
txn,
"device_lists_outbound_pokes",
- {x: row[x] for x in KEY_COLS},
+ {
+ "stream_id": stream_id,
+ "destination": destination,
+ "user_id": user_id,
+ "device_id": device_id,
+ },
)
- row["sent"] = False
self.db_pool.simple_insert_txn(
txn,
"device_lists_outbound_pokes",
- row,
+ {
+ "stream_id": stream_id,
+ "destination": destination,
+ "user_id": user_id,
+ "device_id": device_id,
+ "sent": False,
+ },
)
- if row:
+ if rows:
self.db_pool.updates._background_update_progress_txn(
txn,
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
- {"last_row": row},
+ {
+ "last_row": {
+ "stream_id": stream_id,
+ "destination": destination,
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ },
)
return len(rows)
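
The duplicate-pokes update above finds key tuples that occur more than once via GROUP BY ... HAVING count(*) > 1, deletes all copies, and reinserts a single unsent row. A hedged sqlite3 sketch of just the duplicate detection, on a toy schema:

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute(
    "CREATE TABLE pokes_demo (stream_id INT, destination TEXT, user_id TEXT, device_id TEXT, ts INT)"
)
cur.executemany(
    "INSERT INTO pokes_demo VALUES (?, ?, ?, ?, ?)",
    [
        (1, "hs1", "@a:example.org", "DEV1", 100),
        (1, "hs1", "@a:example.org", "DEV1", 200),  # duplicate key tuple
        (2, "hs2", "@b:example.org", "DEV2", 300),
    ],
)

cur.execute(
    """
    SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
    FROM pokes_demo
    GROUP BY stream_id, destination, user_id, device_id
    HAVING count(*) > 1
    ORDER BY stream_id, destination, user_id, device_id
    LIMIT 10
    """
)
print(cur.fetchall())  # [(1, 'hs1', '@a:example.org', 'DEV1', 200)]
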
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index d01f28cc80..ad904a26a6 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast
from typing_extensions import Literal, TypedDict
@@ -53,6 +53,13 @@ class EndToEndRoomKeyBackgroundStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
+ self.db_pool.updates.register_background_index_update(
+ update_name="e2e_room_keys_index_room_id",
+ index_name="e2e_room_keys_room_id",
+ table="e2e_room_keys",
+ columns=("room_id",),
+ )
+
self.db_pool.updates.register_background_update_handler(
"delete_e2e_backup_keys_for_deactivated_users",
self._delete_e2e_backup_keys_for_deactivated_users,
@@ -208,7 +215,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
- StreamKeyType.ROOM: room_key,
+ StreamKeyType.ROOM.value: room_key,
}
)
@@ -267,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
if session_id:
keyvalues["session_id"] = session_id
- rows = await self.db_pool.simple_select_list(
- table="e2e_room_keys",
- keyvalues=keyvalues,
- retcols=(
- "user_id",
- "room_id",
- "session_id",
- "first_message_index",
- "forwarded_count",
- "is_verified",
- "session_data",
+ rows = cast(
+ List[Tuple[str, str, int, int, int, str]],
+ await self.db_pool.simple_select_list(
+ table="e2e_room_keys",
+ keyvalues=keyvalues,
+ retcols=(
+ "room_id",
+ "session_id",
+ "first_message_index",
+ "forwarded_count",
+ "is_verified",
+ "session_data",
+ ),
+ desc="get_e2e_room_keys",
),
- desc="get_e2e_room_keys",
)
sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}}
- for row in rows:
- room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
- room_entry["sessions"][row["session_id"]] = {
- "first_message_index": row["first_message_index"],
- "forwarded_count": row["forwarded_count"],
+ for (
+ room_id,
+ session_id,
+ first_message_index,
+ forwarded_count,
+ is_verified,
+ session_data,
+ ) in rows:
+ room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
+ room_entry["sessions"][session_id] = {
+ "first_message_index": first_message_index,
+ "forwarded_count": forwarded_count,
# is_verified must be returned to the client as a boolean
- "is_verified": bool(row["is_verified"]),
- "session_data": db_to_json(row["session_data"]),
+ "is_verified": bool(is_verified),
+ "session_data": db_to_json(session_data),
}
return sessions
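For reference, the rewritten loop above builds a nested rooms/sessions structure out of plain tuple rows. A self-contained sketch with made-up values, using json.loads in place of Synapse's db_to_json:

import json
from typing import Any, Dict, List, Tuple

# Rows as returned by the tuple-based select:
# (room_id, session_id, first_message_index, forwarded_count, is_verified, session_data)
rows: List[Tuple[str, str, int, int, int, str]] = [
    ("!room:example.org", "sess1", 0, 0, 1, '{"ciphertext": "..."}'),
    ("!room:example.org", "sess2", 3, 1, 0, '{"ciphertext": "..."}'),
]

sessions: Dict[str, Any] = {"rooms": {}}
for room_id, session_id, first_message_index, forwarded_count, is_verified, session_data in rows:
    room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
    room_entry["sessions"][session_id] = {
        "first_message_index": first_message_index,
        "forwarded_count": forwarded_count,
        # SQLite stores booleans as 0/1, so coerce before returning to clients.
        "is_verified": bool(is_verified),
        "session_data": json.loads(session_data),
    }
print(sessions)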
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 89fac23f93..4f96ac25c7 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -24,6 +24,7 @@ from typing import (
Mapping,
Optional,
Sequence,
+ Set,
Tuple,
Union,
cast,
@@ -155,7 +156,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
- None,
sql,
now_stream_id,
user_id,
@@ -493,15 +493,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
A map from (algorithm, key_id) to json string for key
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="e2e_one_time_keys_json",
- column="key_id",
- iterable=key_ids,
- retcols=("algorithm", "key_id", "key_json"),
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="add_e2e_one_time_keys_check",
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ retcols=("algorithm", "key_id", "key_json"),
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="add_e2e_one_time_keys_check",
+ ),
)
- result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
+ result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows}
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result
@@ -921,14 +924,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
}
txn.execute(sql, params)
- rows = self.db_pool.cursor_to_dict(txn)
- for row in rows:
- user_id = row["user_id"]
- key_type = row["keytype"]
- key = db_to_json(row["keydata"])
+ for user_id, key_type, key_data, _ in txn:
user_keys = result.setdefault(user_id, {})
- user_keys[key_type] = key
+ user_keys[key_type] = db_to_json(key_data)
return result
@@ -988,13 +987,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_params.extend(item)
txn.execute(sql, query_params)
- rows = self.db_pool.cursor_to_dict(txn)
# and add the signatures to the appropriate keys
- for row in rows:
- key_id: str = row["key_id"]
- target_user_id: str = row["target_user_id"]
- target_device_id: str = row["target_device_id"]
+ for target_user_id, target_device_id, key_id, signature in txn:
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
@@ -1012,13 +1007,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
].copy()
if from_user_id in signatures:
user_sigs = signatures[from_user_id] = signatures[from_user_id]
- user_sigs[key_id] = row["signature"]
+ user_sigs[key_id] = signature
else:
- signatures[from_user_id] = {key_id: row["signature"]}
+ signatures[from_user_id] = {key_id: signature}
else:
- target_user_key["signatures"] = {
- from_user_id: {key_id: row["signature"]}
- }
+ target_user_key["signatures"] = {from_user_id: {key_id: signature}}
return keys
@@ -1118,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
...
async def claim_e2e_one_time_keys(
- self, query_list: Iterable[Tuple[str, str, str, int]]
+ self, query_list: Collection[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
@@ -1128,131 +1121,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_list: An iterable of tuples of (user ID, device ID, algorithm, count).
Returns:
- A tuple pf:
+ A tuple (results, missing) of:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
- A copy of the input which has not been fulfilled.
+ A copy of the input which has not been fulfilled. The returned counts
+ may be less than the requested counts; in that case they indicate how
+ many of the claims could not be fulfilled.
"""
-
- @trace
- def _claim_e2e_one_time_key_simple(
- txn: LoggingTransaction,
- user_id: str,
- device_id: str,
- algorithm: str,
- count: int,
- ) -> List[Tuple[str, str]]:
- """Claim OTK for device for DBs that don't support RETURNING.
-
- Returns:
- A tuple of key name (algorithm + key ID) and key JSON, if an
- OTK was found.
- """
-
- sql = """
- SELECT key_id, key_json FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT ?
- """
-
- txn.execute(sql, (user_id, device_id, algorithm, count))
- otk_rows = list(txn)
- if not otk_rows:
- return []
-
- self.db_pool.simple_delete_many_txn(
- txn,
- table="e2e_one_time_keys_json",
- column="key_id",
- values=[otk_row[0] for otk_row in otk_rows],
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- "algorithm": algorithm,
- },
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return [
- (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
- ]
-
- @trace
- def _claim_e2e_one_time_key_returning(
- txn: LoggingTransaction,
- user_id: str,
- device_id: str,
- algorithm: str,
- count: int,
- ) -> List[Tuple[str, str]]:
- """Claim OTK for device for DBs that support RETURNING.
-
- Returns:
- A tuple of key name (algorithm + key ID) and key JSON, if an
- OTK was found.
- """
-
- # We can use RETURNING to do the fetch and DELETE in once step.
- sql = """
- DELETE FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- AND key_id IN (
- SELECT key_id FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT ?
- )
- RETURNING key_id, key_json
- """
-
- txn.execute(
- sql,
- (user_id, device_id, algorithm, user_id, device_id, algorithm, count),
- )
- otk_rows = list(txn)
- if not otk_rows:
- return []
-
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return [
- (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
- ]
-
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
missing: List[Tuple[str, str, str, int]] = []
- for user_id, device_id, algorithm, count in query_list:
- if self.database_engine.supports_returning:
- # If we support RETURNING clause we can use a single query that
- # allows us to use autocommit mode.
- _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
- db_autocommit = True
- else:
- _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
- db_autocommit = False
-
- claim_rows = await self.db_pool.runInteraction(
+ if isinstance(self.database_engine, PostgresEngine):
+ # If we can use execute_values we can use a single batch query
+ # in autocommit mode.
+ unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
+ for user_id, device_id, algorithm, count in query_list:
+ unfulfilled_claim_counts[user_id, device_id, algorithm] = count
+
+ bulk_claims = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
- _claim_e2e_one_time_key,
- user_id,
- device_id,
- algorithm,
- count,
- db_autocommit=db_autocommit,
+ self._claim_e2e_one_time_keys_bulk,
+ query_list,
+ db_autocommit=True,
)
- if claim_rows:
+
+ for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
- for claim_row in claim_rows:
- device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+ unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
+
# Did we get enough OTKs?
- count -= len(claim_rows)
- if count:
- missing.append((user_id, device_id, algorithm, count))
+ missing = [
+ (user, device, alg, count)
+ for (user, device, alg), count in unfulfilled_claim_counts.items()
+ if count > 0
+ ]
+ else:
+ for user_id, device_id, algorithm, count in query_list:
+ claim_rows = await self.db_pool.runInteraction(
+ "claim_e2e_one_time_keys",
+ self._claim_e2e_one_time_key_simple,
+ user_id,
+ device_id,
+ algorithm,
+ count,
+ db_autocommit=False,
+ )
+ if claim_rows:
+ device_results = results.setdefault(user_id, {}).setdefault(
+ device_id, {}
+ )
+ for claim_row in claim_rows:
+ device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ # Did we get enough OTKs?
+ count -= len(claim_rows)
+ if count:
+ missing.append((user_id, device_id, algorithm, count))
return results, missing
@@ -1268,6 +1193,65 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
+ if isinstance(self.database_engine, PostgresEngine):
+ return await self.db_pool.runInteraction(
+ "_claim_e2e_fallback_keys_bulk",
+ self._claim_e2e_fallback_keys_bulk_txn,
+ query_list,
+ db_autocommit=True,
+ )
+ # Use an UPDATE FROM... RETURNING combined with a VALUES block to do
+ # everything in one query. Note: this is also supported in SQLite 3.33.0
+ # (see https://www.sqlite.org/lang_update.html#update_from), but we do not
+ # have an equivalent of psycopg2's execute_values to do this in one query.
+ else:
+ return await self._claim_e2e_fallback_keys_simple(query_list)
+
+ def _claim_e2e_fallback_keys_bulk_txn(
+ self,
+ txn: LoggingTransaction,
+ query_list: Iterable[Tuple[str, str, str, bool]],
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Efficient implementation of claim_e2e_fallback_keys for Postgres.
+
+ Safe to autocommit: this is a single query.
+ """
+ results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+
+ sql = """
+ WITH claims(user_id, device_id, algorithm, mark_as_used) AS (
+ VALUES ?
+ )
+ UPDATE e2e_fallback_keys_json k
+ SET used = used OR mark_as_used
+ FROM claims
+ WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
+ RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json;
+ """
+ claimed_keys = cast(
+ List[Tuple[str, str, str, str, str]],
+ txn.execute_values(sql, query_list),
+ )
+
+ seen_user_device: Set[Tuple[str, str]] = set()
+ for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
+ device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
+ device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+
+ if (user_id, device_id) in seen_user_device:
+ continue
+ seen_user_device.add((user_id, device_id))
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
+
+ return results
+
+ async def _claim_e2e_fallback_keys_simple(
+ self,
+ query_list: Iterable[Tuple[str, str, str, bool]],
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Naive, inefficient implementation of claim_e2e_fallback_keys for SQLite."""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one(
@@ -1310,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return results
+ @trace
+ def _claim_e2e_one_time_key_simple(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ algorithm: str,
+ count: int,
+ ) -> List[Tuple[str, str]]:
+ """Claim OTK for device for DBs that don't support RETURNING.
+
+ Returns:
+ A list of (key name, key JSON) tuples, where the key name is
+ "algorithm:key_id". Empty if no OTKs were found.
+ """
+
+ sql = """
+ SELECT key_id, key_json FROM e2e_one_time_keys_json
+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
+ LIMIT ?
+ """
+
+ txn.execute(sql, (user_id, device_id, algorithm, count))
+ otk_rows = list(txn)
+ if not otk_rows:
+ return []
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ values=[otk_row[0] for otk_row in otk_rows],
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
+
+ @trace
+ def _claim_e2e_one_time_keys_bulk(
+ self,
+ txn: LoggingTransaction,
+ query_list: Iterable[Tuple[str, str, str, int]],
+ ) -> List[Tuple[str, str, str, str, str]]:
+ """Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
+
+ Args:
+ query_list: Collection of tuples (user_id, device_id, algorithm, count)
+ as passed to claim_e2e_one_time_keys.
+
+ Returns:
+ A list of tuples (user_id, device_id, algorithm, key_id, key_json),
+ one for each OTK claimed.
+ """
+ sql = """
+ WITH claims(user_id, device_id, algorithm, claim_count) AS (
+ VALUES ?
+ ), ranked_keys AS (
+ SELECT
+ user_id, device_id, algorithm, key_id, claim_count,
+ ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
+ FROM e2e_one_time_keys_json
+ JOIN claims USING (user_id, device_id, algorithm)
+ )
+ DELETE FROM e2e_one_time_keys_json k
+ WHERE (user_id, device_id, algorithm, key_id) IN (
+ SELECT user_id, device_id, algorithm, key_id
+ FROM ranked_keys
+ WHERE r <= claim_count
+ )
+ RETURNING user_id, device_id, algorithm, key_id, key_json;
+ """
+ otk_rows = cast(
+ List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
+ )
+
+ seen_user_device: Set[Tuple[str, str]] = set()
+ for user_id, device_id, _, _, _ in otk_rows:
+ if (user_id, device_id) in seen_user_device:
+ continue
+ seen_user_device.add((user_id, device_id))
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return otk_rows
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
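In the PostgreSQL bulk path above, a per-(user, device, algorithm) counter is decremented for every row the batched DELETE ... RETURNING query hands back, and whatever remains positive becomes the `missing` list. A standalone sketch of that bookkeeping, with fabricated claim data:

from typing import Dict, List, Tuple

query_list = [
    ("@alice:example.org", "DEV1", "signed_curve25519", 2),
    ("@bob:example.org", "DEV2", "signed_curve25519", 1),
]

# Rows as a DELETE ... RETURNING (or execute_values) call might return them:
# (user_id, device_id, algorithm, key_id, key_json)
returned_rows = [
    ("@alice:example.org", "DEV1", "signed_curve25519", "AAAA", "{}"),
]

unfulfilled: Dict[Tuple[str, str, str], int] = {
    (user_id, device_id, algorithm): count
    for user_id, device_id, algorithm, count in query_list
}
for user_id, device_id, algorithm, _key_id, _key_json in returned_rows:
    unfulfilled[(user_id, device_id, algorithm)] -= 1

missing: List[Tuple[str, str, str, int]] = [
    (user, device, alg, count)
    for (user, device, alg), count in unfulfilled.items()
    if count > 0
]
print(missing)  # alice still needs 1 key, bob still needs 1 key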
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index d4251be7e7..b8bbd1eccd 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1048,15 +1048,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
Args:
event_ids: The event IDs to calculate the max depth of.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="events",
- column="event_id",
- iterable=event_ids,
- retcols=(
- "event_id",
- "depth",
+ rows = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=(
+ "event_id",
+ "depth",
+ ),
+ desc="get_max_depth_of",
),
- desc="get_max_depth_of",
)
if not rows:
@@ -1064,10 +1067,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
else:
max_depth_event_id = ""
current_max_depth = 0
- for row in rows:
- if row["depth"] > current_max_depth:
- max_depth_event_id = row["event_id"]
- current_max_depth = row["depth"]
+ for event_id, depth in rows:
+ if depth > current_max_depth:
+ max_depth_event_id = event_id
+ current_max_depth = depth
return max_depth_event_id, current_max_depth
@@ -1077,15 +1080,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
Args:
event_ids: The event IDs to calculate the max depth of.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="events",
- column="event_id",
- iterable=event_ids,
- retcols=(
- "event_id",
- "depth",
+ rows = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=(
+ "event_id",
+ "depth",
+ ),
+ desc="get_min_depth_of",
),
- desc="get_min_depth_of",
)
if not rows:
@@ -1093,10 +1099,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
else:
min_depth_event_id = ""
current_min_depth = MAX_DEPTH
- for row in rows:
- if row["depth"] < current_min_depth:
- min_depth_event_id = row["event_id"]
- current_min_depth = row["depth"]
+ for event_id, depth in rows:
+ if depth < current_min_depth:
+ min_depth_event_id = event_id
+ current_min_depth = depth
return min_depth_event_id, current_min_depth
@@ -1552,19 +1558,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
A filtered down list of `event_ids` that have previous failed pull attempts.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="event_failed_pull_attempts",
- column="event_id",
- iterable=event_ids,
- keyvalues={},
- retcols=("event_id",),
- desc="get_event_ids_with_failed_pull_attempts",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_failed_pull_attempts",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=("event_id",),
+ desc="get_event_ids_with_failed_pull_attempts",
+ ),
)
- event_ids_with_failed_pull_attempts: Set[str] = {
- row["event_id"] for row in rows
- }
-
- return event_ids_with_failed_pull_attempts
+ return {row[0] for row in rows}
@trace
async def get_event_ids_to_not_pull_from_backoff(
@@ -1584,32 +1589,34 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
A dictionary of event_ids that should not be attempted to be pulled and the
next timestamp at which we may try pulling them again.
"""
- event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
- table="event_failed_pull_attempts",
- column="event_id",
- iterable=event_ids,
- keyvalues={},
- retcols=(
- "event_id",
- "last_attempt_ts",
- "num_attempts",
+ event_failed_pull_attempts = cast(
+ List[Tuple[str, int, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_failed_pull_attempts",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=(
+ "event_id",
+ "last_attempt_ts",
+ "num_attempts",
+ ),
+ desc="get_event_ids_to_not_pull_from_backoff",
),
- desc="get_event_ids_to_not_pull_from_backoff",
)
current_time = self._clock.time_msec()
event_ids_with_backoff = {}
- for event_failed_pull_attempt in event_failed_pull_attempts:
- event_id = event_failed_pull_attempt["event_id"]
+ for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts:
# Exponential back-off (up to the upper bound) so we don't try to
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
backoff_end_time = (
- event_failed_pull_attempt["last_attempt_ts"]
+ last_attempt_ts
+ (
2
** min(
- event_failed_pull_attempt["num_attempts"],
+ num_attempts,
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
)
)
@@ -1890,21 +1897,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# keeping only the forward extremities (i.e. the events not referenced
# by other events in the queue). We do this so that we can always
# backpaginate in all the events we have dropped.
- rows = await self.db_pool.simple_select_list(
- table="federation_inbound_events_staging",
- keyvalues={"room_id": room_id},
- retcols=("event_id", "event_json"),
- desc="prune_staged_events_in_room_fetch",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "event_json"),
+ desc="prune_staged_events_in_room_fetch",
+ ),
)
# Find the set of events referenced by those in the queue, as well as
# collecting all the event IDs in the queue.
referenced_events: Set[str] = set()
seen_events: Set[str] = set()
- for row in rows:
- event_id = row["event_id"]
+ for event_id, event_json in rows:
seen_events.add(event_id)
- event_d = db_to_json(row["event_json"])
+ event_d = db_to_json(event_json)
# We don't bother parsing the dicts into full blown event objects,
# as that is needlessly expensive.
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index ed29d1fa5d..e4dc68c0d8 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -182,6 +182,7 @@ class UserPushAction(EmailPushAction):
profile_tag: str
+# TODO This is used as a cached value and is mutable.
@attr.s(slots=True, auto_attribs=True)
class NotifCounts:
"""
@@ -193,7 +194,7 @@ class NotifCounts:
highlight_count: int = 0
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomNotifCounts:
"""
The per-user, per-room count of notifications. Used by sync and push.
@@ -201,7 +202,7 @@ class RoomNotifCounts:
main_timeline: NotifCounts
# Map of thread ID to the notification counts.
- threads: Dict[str, NotifCounts]
+ threads: Mapping[str, NotifCounts]
@staticmethod
def empty() -> "RoomNotifCounts":
@@ -483,7 +484,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return room_to_count
- @cached(tree=True, max_entries=5000, iterable=True)
+ @cached(tree=True, max_entries=5000, iterable=True) # type: ignore[synapse-@cached-mutable]
async def get_unread_event_push_actions_by_room_for_user(
self,
room_id: str,
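RoomNotifCounts is frozen (and its threads field narrowed to Mapping) because instances are stored in a shared cache; mutating a cached value in place would corrupt it for every caller. A minimal sketch of what the freeze buys, assuming the attrs package is installed:

import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class Counts:
    notify_count: int = 0
    highlight_count: int = 0

cached_value = Counts(notify_count=3)
try:
    # A caller that tries to tweak the cached object in place now fails
    # loudly instead of silently corrupting the cache for everyone else.
    cached_value.notify_count = 0
except attr.exceptions.FrozenInstanceError:
    print("cached value is immutable; build a new Counts instead")
print(attr.evolve(cached_value, notify_count=0))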
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 790d058c43..7c34bde3e5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -27,6 +27,7 @@ from typing import (
Optional,
Set,
Tuple,
+ Union,
cast,
)
@@ -78,7 +79,7 @@ class DeltaState:
Attributes:
to_delete: List of type/state_keys to delete from current state
to_insert: Map of state to upsert into current state
- no_longer_in_room: The server is not longer in the room, so the room
+ no_longer_in_room: The server is no longer in the room, so the room
should e.g. be removed from `current_state_events` table.
"""
@@ -130,22 +131,25 @@ class PersistEventsStore:
@trace
async def _persist_events_and_state_updates(
self,
+ room_id: str,
events_and_contexts: List[Tuple[EventBase, EventContext]],
*,
- state_delta_for_room: Dict[str, DeltaState],
- new_forward_extremities: Dict[str, Set[str]],
+ state_delta_for_room: Optional[DeltaState],
+ new_forward_extremities: Optional[Set[str]],
use_negative_stream_ordering: bool = False,
inhibit_local_membership_updates: bool = False,
) -> None:
"""Persist a set of events alongside updates to the current state and
- forward extremities tables.
+ forward extremities tables.
+
+ Assumes that we are only persisting events for one room at a time.
Args:
+ room_id:
events_and_contexts:
- state_delta_for_room: Map from room_id to the delta to apply to
- room state
- new_forward_extremities: Map from room_id to set of event IDs
- that are the new forward extremities of the room.
+ state_delta_for_room: The delta to apply to the room state
+ new_forward_extremities: A set of event IDs that are the new forward
+ extremities of the room.
use_negative_stream_ordering: Whether to start stream_ordering on
the negative side and decrement. This should be set as True
for backfilled events because backfilled events get a negative
@@ -195,6 +199,7 @@ class PersistEventsStore:
await self.db_pool.runInteraction(
"persist_events",
self._persist_events_txn,
+ room_id=room_id,
events_and_contexts=events_and_contexts,
inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
@@ -220,9 +225,9 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
- for room_id, latest_event_ids in new_forward_extremities.items():
+ if new_forward_extremities:
self.store.get_latest_event_ids_in_room.prefill(
- (room_id,), frozenset(latest_event_ids)
+ (room_id,), frozenset(new_forward_extremities)
)
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
@@ -335,10 +340,11 @@ class PersistEventsStore:
self,
txn: LoggingTransaction,
*,
+ room_id: str,
events_and_contexts: List[Tuple[EventBase, EventContext]],
inhibit_local_membership_updates: bool,
- state_delta_for_room: Dict[str, DeltaState],
- new_forward_extremities: Dict[str, Set[str]],
+ state_delta_for_room: Optional[DeltaState],
+ new_forward_extremities: Optional[Set[str]],
) -> None:
"""Insert some number of room events into the necessary database tables.
@@ -346,8 +352,11 @@ class PersistEventsStore:
and the rejections table. Things reading from those table will need to check
whether the event was rejected.
+ Assumes that we are only persisting events for one room at a time.
+
Args:
txn
+ room_id: The room the events are from
events_and_contexts: events to persist
inhibit_local_membership_updates: Stop the local_current_membership
from being updated by these events. This should be set to True
@@ -356,10 +365,9 @@ class PersistEventsStore:
delete_existing True to purge existing table rows for the events
from the database. This is useful when retrying due to
IntegrityError.
- state_delta_for_room: The current-state delta for each room.
- new_forward_extremities: The new forward extremities for each room.
- For each room, a list of the event ids which are the forward
- extremities.
+ state_delta_for_room: The current-state delta for the room.
+ new_forward_extremities: The new forward extremities for the room:
+ a set of the event ids which are the forward extremities.
Raises:
PartialStateConflictError: if attempting to persist a partial state event in
@@ -375,14 +383,13 @@ class PersistEventsStore:
#
# Annoyingly SQLite doesn't support row level locking.
if isinstance(self.database_engine, PostgresEngine):
- for room_id in {e.room_id for e, _ in events_and_contexts}:
- txn.execute(
- "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
- (room_id,),
- )
- row = txn.fetchone()
- if row is None:
- raise Exception(f"Room does not exist {room_id}")
+ txn.execute(
+ "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
+ (room_id,),
+ )
+ row = txn.fetchone()
+ if row is None:
+ raise Exception(f"Room does not exist {room_id}")
# stream orderings should have been assigned by now
assert min_stream_order
@@ -418,7 +425,9 @@ class PersistEventsStore:
events_and_contexts
)
- self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
+ self._update_room_depths_txn(
+ txn, room_id, events_and_contexts=events_and_contexts
+ )
# _update_outliers_txn filters out any events which have already been
# persisted, and returns the filtered list.
@@ -431,11 +440,13 @@ class PersistEventsStore:
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
- self._update_forward_extremities_txn(
- txn,
- new_forward_extremities=new_forward_extremities,
- max_stream_order=max_stream_order,
- )
+ if new_forward_extremities:
+ self._update_forward_extremities_txn(
+ txn,
+ room_id,
+ new_forward_extremities=new_forward_extremities,
+ max_stream_order=max_stream_order,
+ )
self._persist_transaction_ids_txn(txn, events_and_contexts)
@@ -463,7 +474,10 @@ class PersistEventsStore:
# We call this last as it assumes we've inserted the events into
# room_memberships, where applicable.
# NB: This function invalidates all state related caches
- self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+ if state_delta_for_room:
+ self._update_current_state_txn(
+ txn, room_id, state_delta_for_room, min_stream_order
+ )
def _persist_event_auth_chain_txn(
self,
@@ -501,16 +515,19 @@ class PersistEventsStore:
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="rooms",
- column="room_id",
- iterable={event.room_id for event in events if event.is_state()},
- keyvalues={},
- retcols=("room_id", "has_auth_chain_index"),
+ rows = cast(
+ List[Tuple[str, Optional[Union[int, bool]]]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="rooms",
+ column="room_id",
+ iterable={event.room_id for event in events if event.is_state()},
+ keyvalues={},
+ retcols=("room_id", "has_auth_chain_index"),
+ ),
)
rooms_using_chain_index = {
- row["room_id"] for row in rows if row["has_auth_chain_index"]
+ room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
}
state_events = {
@@ -571,19 +588,18 @@ class PersistEventsStore:
# We check if there are any events that need to be handled in the rooms
# we're looking at. These should just be out of band memberships, where
# we didn't have the auth chain when we first persisted.
- rows = db_pool.simple_select_many_txn(
- txn,
- table="event_auth_chain_to_calculate",
- keyvalues={},
- column="room_id",
- iterable=set(event_to_room_id.values()),
- retcols=("event_id", "type", "state_key"),
+ auth_chain_to_calc_rows = cast(
+ List[Tuple[str, str, str]],
+ db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="room_id",
+ iterable=set(event_to_room_id.values()),
+ retcols=("event_id", "type", "state_key"),
+ ),
)
- for row in rows:
- event_id = row["event_id"]
- event_type = row["type"]
- state_key = row["state_key"]
-
+ for event_id, event_type, state_key in auth_chain_to_calc_rows:
# (We could pull out the auth events for all rows at once using
# simple_select_many, but this case happens rarely and almost always
# with a single row.)
@@ -753,23 +769,31 @@ class PersistEventsStore:
# Step 1, fetch all existing links from all the chains we've seen
# referenced.
chain_links = _LinkMap()
- rows = db_pool.simple_select_many_txn(
- txn,
- table="event_auth_chain_links",
- column="origin_chain_id",
- iterable={chain_id for chain_id, _ in chain_map.values()},
- keyvalues={},
- retcols=(
- "origin_chain_id",
- "origin_sequence_number",
- "target_chain_id",
- "target_sequence_number",
+ auth_chain_rows = cast(
+ List[Tuple[int, int, int, int]],
+ db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable={chain_id for chain_id, _ in chain_map.values()},
+ keyvalues={},
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
),
)
- for row in rows:
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in auth_chain_rows:
chain_links.add_link(
- (row["origin_chain_id"], row["origin_sequence_number"]),
- (row["target_chain_id"], row["target_sequence_number"]),
+ (origin_chain_id, origin_sequence_number),
+ (target_chain_id, target_sequence_number),
new=False,
)
@@ -1015,74 +1039,75 @@ class PersistEventsStore:
await self.db_pool.runInteraction(
"update_current_state",
self._update_current_state_txn,
- state_delta_by_room={room_id: state_delta},
+ room_id,
+ delta_state=state_delta,
stream_id=stream_ordering,
)
def _update_current_state_txn(
self,
txn: LoggingTransaction,
- state_delta_by_room: Dict[str, DeltaState],
+ room_id: str,
+ delta_state: DeltaState,
stream_id: int,
) -> None:
- for room_id, delta_state in state_delta_by_room.items():
- to_delete = delta_state.to_delete
- to_insert = delta_state.to_insert
-
- # Figure out the changes of membership to invalidate the
- # `get_rooms_for_user` cache.
- # We find out which membership events we may have deleted
- # and which we have added, then we invalidate the caches for all
- # those users.
- members_changed = {
- state_key
- for ev_type, state_key in itertools.chain(to_delete, to_insert)
- if ev_type == EventTypes.Member
- }
+ to_delete = delta_state.to_delete
+ to_insert = delta_state.to_insert
+
+ # Figure out the changes of membership to invalidate the
+ # `get_rooms_for_user` cache.
+ # We find out which membership events we may have deleted
+ # and which we have added, then we invalidate the caches for all
+ # those users.
+ members_changed = {
+ state_key
+ for ev_type, state_key in itertools.chain(to_delete, to_insert)
+ if ev_type == EventTypes.Member
+ }
- if delta_state.no_longer_in_room:
- # Server is no longer in the room so we delete the room from
- # current_state_events, being careful we've already updated the
- # rooms.room_version column (which gets populated in a
- # background task).
- self._upsert_room_version_txn(txn, room_id)
+ if delta_state.no_longer_in_room:
+ # Server is no longer in the room so we delete the room from
+ # current_state_events, being careful we've already updated the
+ # rooms.room_version column (which gets populated in a
+ # background task).
+ self._upsert_room_version_txn(txn, room_id)
- # Before deleting we populate the current_state_delta_stream
- # so that async background tasks get told what happened.
- sql = """
+ # Before deleting we populate the current_state_delta_stream
+ # so that async background tasks get told what happened.
+ sql = """
INSERT INTO current_state_delta_stream
(stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
SELECT ?, ?, room_id, type, state_key, null, event_id
FROM current_state_events
WHERE room_id = ?
"""
- txn.execute(sql, (stream_id, self._instance_name, room_id))
+ txn.execute(sql, (stream_id, self._instance_name, room_id))
- # We also want to invalidate the membership caches for users
- # that were in the room.
- users_in_room = self.store.get_users_in_room_txn(txn, room_id)
- members_changed.update(users_in_room)
+ # We also want to invalidate the membership caches for users
+ # that were in the room.
+ users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+ members_changed.update(users_in_room)
- self.db_pool.simple_delete_txn(
- txn,
- table="current_state_events",
- keyvalues={"room_id": room_id},
- )
- else:
- # We're still in the room, so we update the current state as normal.
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="current_state_events",
+ keyvalues={"room_id": room_id},
+ )
+ else:
+ # We're still in the room, so we update the current state as normal.
- # First we add entries to the current_state_delta_stream. We
- # do this before updating the current_state_events table so
- # that we can use it to calculate the `prev_event_id`. (This
- # allows us to not have to pull out the existing state
- # unnecessarily).
- #
- # The stream_id for the update is chosen to be the minimum of the stream_ids
- # for the batch of the events that we are persisting; that means we do not
- # end up in a situation where workers see events before the
- # current_state_delta updates.
- #
- sql = """
+ # First we add entries to the current_state_delta_stream. We
+ # do this before updating the current_state_events table so
+ # that we can use it to calculate the `prev_event_id`. (This
+ # allows us to not have to pull out the existing state
+ # unnecessarily).
+ #
+ # The stream_id for the update is chosen to be the minimum of the stream_ids
+ # for the batch of the events that we are persisting; that means we do not
+ # end up in a situation where workers see events before the
+ # current_state_delta updates.
+ #
+ sql = """
INSERT INTO current_state_delta_stream
(stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
SELECT ?, ?, ?, ?, ?, ?, (
@@ -1090,39 +1115,39 @@ class PersistEventsStore:
WHERE room_id = ? AND type = ? AND state_key = ?
)
"""
- txn.execute_batch(
- sql,
+ txn.execute_batch(
+ sql,
+ (
(
- (
- stream_id,
- self._instance_name,
- room_id,
- etype,
- state_key,
- to_insert.get((etype, state_key)),
- room_id,
- etype,
- state_key,
- )
- for etype, state_key in itertools.chain(to_delete, to_insert)
- ),
- )
- # Now we actually update the current_state_events table
+ stream_id,
+ self._instance_name,
+ room_id,
+ etype,
+ state_key,
+ to_insert.get((etype, state_key)),
+ room_id,
+ etype,
+ state_key,
+ )
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
+ )
+ # Now we actually update the current_state_events table
- txn.execute_batch(
- "DELETE FROM current_state_events"
- " WHERE room_id = ? AND type = ? AND state_key = ?",
- (
- (room_id, etype, state_key)
- for etype, state_key in itertools.chain(to_delete, to_insert)
- ),
- )
+ txn.execute_batch(
+ "DELETE FROM current_state_events"
+ " WHERE room_id = ? AND type = ? AND state_key = ?",
+ (
+ (room_id, etype, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
+ )
- # We include the membership in the current state table, hence we do
- # a lookup when we insert. This assumes that all events have already
- # been inserted into room_memberships.
- txn.execute_batch(
- """INSERT INTO current_state_events
+ # We include the membership in the current state table, hence we do
+ # a lookup when we insert. This assumes that all events have already
+ # been inserted into room_memberships.
+ txn.execute_batch(
+ """INSERT INTO current_state_events
(room_id, type, state_key, event_id, membership, event_stream_ordering)
VALUES (
?, ?, ?, ?,
@@ -1130,34 +1155,34 @@ class PersistEventsStore:
(SELECT stream_ordering FROM events WHERE event_id = ?)
)
""",
- [
- (room_id, key[0], key[1], ev_id, ev_id, ev_id)
- for key, ev_id in to_insert.items()
- ],
- )
+ [
+ (room_id, key[0], key[1], ev_id, ev_id, ev_id)
+ for key, ev_id in to_insert.items()
+ ],
+ )
- # We now update `local_current_membership`. We do this regardless
- # of whether we're still in the room or not to handle the case where
- # e.g. we just got banned (where we need to record that fact here).
-
- # Note: Do we really want to delete rows here (that we do not
- # subsequently reinsert below)? While technically correct it means
- # we have no record of the fact the user *was* a member of the
- # room but got, say, state reset out of it.
- if to_delete or to_insert:
- txn.execute_batch(
- "DELETE FROM local_current_membership"
- " WHERE room_id = ? AND user_id = ?",
- (
- (room_id, state_key)
- for etype, state_key in itertools.chain(to_delete, to_insert)
- if etype == EventTypes.Member and self.is_mine_id(state_key)
- ),
- )
+ # We now update `local_current_membership`. We do this regardless
+ # of whether we're still in the room or not to handle the case where
+ # e.g. we just got banned (where we need to record that fact here).
- if to_insert:
- txn.execute_batch(
- """INSERT INTO local_current_membership
+ # Note: Do we really want to delete rows here (that we do not
+ # subsequently reinsert below)? While technically correct it means
+ # we have no record of the fact the user *was* a member of the
+ # room but got, say, state reset out of it.
+ if to_delete or to_insert:
+ txn.execute_batch(
+ "DELETE FROM local_current_membership"
+ " WHERE room_id = ? AND user_id = ?",
+ (
+ (room_id, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ if etype == EventTypes.Member and self.is_mine_id(state_key)
+ ),
+ )
+
+ if to_insert:
+ txn.execute_batch(
+ """INSERT INTO local_current_membership
(room_id, user_id, event_id, membership, event_stream_ordering)
VALUES (
?, ?, ?,
@@ -1165,29 +1190,27 @@ class PersistEventsStore:
(SELECT stream_ordering FROM events WHERE event_id = ?)
)
""",
- [
- (room_id, key[1], ev_id, ev_id, ev_id)
- for key, ev_id in to_insert.items()
- if key[0] == EventTypes.Member and self.is_mine_id(key[1])
- ],
- )
-
- txn.call_after(
- self.store._curr_state_delta_stream_cache.entity_has_changed,
- room_id,
- stream_id,
+ [
+ (room_id, key[1], ev_id, ev_id, ev_id)
+ for key, ev_id in to_insert.items()
+ if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+ ],
)
- # Invalidate the various caches
- self.store._invalidate_state_caches_and_stream(
- txn, room_id, members_changed
- )
+ txn.call_after(
+ self.store._curr_state_delta_stream_cache.entity_has_changed,
+ room_id,
+ stream_id,
+ )
- # Check if any of the remote membership changes requires us to
- # unsubscribe from their device lists.
- self.store.handle_potentially_left_users_txn(
- txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
- )
+ # Invalidate the various caches
+ self.store._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+
+ # Check if any of the remote membership changes requires us to
+ # unsubscribe from their device lists.
+ self.store.handle_potentially_left_users_txn(
+ txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
+ )
def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
"""Update the room version in the database based off current state
@@ -1221,23 +1244,19 @@ class PersistEventsStore:
def _update_forward_extremities_txn(
self,
txn: LoggingTransaction,
- new_forward_extremities: Dict[str, Set[str]],
+ room_id: str,
+ new_forward_extremities: Set[str],
max_stream_order: int,
) -> None:
- for room_id in new_forward_extremities.keys():
- self.db_pool.simple_delete_txn(
- txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
- )
+ self.db_pool.simple_delete_txn(
+ txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
+ )
self.db_pool.simple_insert_many_txn(
txn,
table="event_forward_extremities",
keys=("event_id", "room_id"),
- values=[
- (ev_id, room_id)
- for room_id, new_extrem in new_forward_extremities.items()
- for ev_id in new_extrem
- ],
+ values=[(ev_id, room_id) for ev_id in new_forward_extremities],
)
# We now insert into stream_ordering_to_exterm a mapping from room_id,
# new stream_ordering to new forward extremeties in the room.
@@ -1249,8 +1268,7 @@ class PersistEventsStore:
keys=("room_id", "event_id", "stream_ordering"),
values=[
(room_id, event_id, max_stream_order)
- for room_id, new_extrem in new_forward_extremities.items()
- for event_id in new_extrem
+ for event_id in new_forward_extremities
],
)
@@ -1287,36 +1305,45 @@ class PersistEventsStore:
def _update_room_depths_txn(
self,
txn: LoggingTransaction,
+ room_id: str,
events_and_contexts: List[Tuple[EventBase, EventContext]],
) -> None:
"""Update min_depth for each room
Args:
txn: db connection
+ room_id: The room ID
events_and_contexts: events we are persisting
"""
- depth_updates: Dict[str, int] = {}
+ stream_ordering: Optional[int] = None
+ depth_update = 0
for event, context in events_and_contexts:
- # Then update the `stream_ordering` position to mark the latest
- # event as the front of the room. This should not be done for
- # backfilled events because backfilled events have negative
- # stream_ordering and happened in the past so we know that we don't
- # need to update the stream_ordering tip/front for the room.
+ # Don't update the stream ordering for backfilled events because
+ # backfilled events have negative stream_ordering and happened in the
+ # past, so we know that we don't need to update the stream_ordering
+ # tip/front for the room.
assert event.internal_metadata.stream_ordering is not None
if event.internal_metadata.stream_ordering >= 0:
- txn.call_after(
- self.store._events_stream_cache.entity_has_changed,
- event.room_id,
- event.internal_metadata.stream_ordering,
- )
+ if stream_ordering is None:
+ stream_ordering = event.internal_metadata.stream_ordering
+ else:
+ stream_ordering = max(
+ stream_ordering, event.internal_metadata.stream_ordering
+ )
if not event.internal_metadata.is_outlier() and not context.rejected:
- depth_updates[event.room_id] = max(
- event.depth, depth_updates.get(event.room_id, event.depth)
- )
+ depth_update = max(event.depth, depth_update)
- for room_id, depth in depth_updates.items():
- self._update_min_depth_for_room_txn(txn, room_id, depth)
+ # Then update the `stream_ordering` position to mark the latest event as
+ # the front of the room.
+ if stream_ordering is not None:
+ txn.call_after(
+ self.store._events_stream_cache.entity_has_changed,
+ room_id,
+ stream_ordering,
+ )
+
+ self._update_min_depth_for_room_txn(txn, room_id, depth_update)
def _update_outliers_txn(
self,
@@ -1339,13 +1366,19 @@ class PersistEventsStore:
PartialStateConflictError: if attempting to persist a partial state event in
a room that has been un-partial stated.
"""
- txn.execute(
- "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
- % (",".join(["?"] * len(events_and_contexts)),),
- [event.event_id for event, _ in events_and_contexts],
+ rows = cast(
+ List[Tuple[str, bool]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ "events",
+ "event_id",
+ [event.event_id for event, _ in events_and_contexts],
+ keyvalues={},
+ retcols=("event_id", "outlier"),
+ ),
)
- have_persisted = dict(cast(Iterable[Tuple[str, bool]], txn))
+ have_persisted = dict(rows)
logger.debug(
"_update_outliers_txn: events=%s have_persisted=%s",
@@ -1443,7 +1476,7 @@ class PersistEventsStore:
txn,
table="event_json",
keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
- values=(
+ values=[
(
event.event_id,
event.room_id,
@@ -1452,7 +1485,7 @@ class PersistEventsStore:
event.format_version,
)
for event, _ in events_and_contexts
- ),
+ ],
)
self.db_pool.simple_insert_many_txn(
@@ -1475,7 +1508,7 @@ class PersistEventsStore:
"state_key",
"rejection_reason",
),
- values=(
+ values=[
(
self._instance_name,
event.internal_metadata.stream_ordering,
@@ -1494,7 +1527,7 @@ class PersistEventsStore:
context.rejected,
)
for event, context in events_and_contexts
- ),
+ ],
)
# If we're persisting an unredacted event we go and ensure
@@ -1517,11 +1550,11 @@ class PersistEventsStore:
txn,
table="state_events",
keys=("event_id", "room_id", "type", "state_key"),
- values=(
+ values=[
(event.event_id, event.room_id, event.type, event.state_key)
for event, _ in events_and_contexts
if event.is_state()
- ),
+ ],
)
def _store_rejected_events_txn(
@@ -1654,8 +1687,6 @@ class PersistEventsStore:
) -> None:
to_prefill = []
- rows = []
-
ev_map = {e.event_id: e for e, _ in events_and_contexts}
if not ev_map:
return
@@ -1676,10 +1707,9 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
- rows = self.db_pool.cursor_to_dict(txn)
- for row in rows:
- event = ev_map[row["event_id"]]
- if not row["rejects"] and not row["redacts"]:
+ for event_id, redacts, rejects in txn:
+ event = ev_map[event_id]
+ if not rejects and not redacts:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
async def external_prefill() -> None:
@@ -2259,35 +2289,59 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
- # From the events passed in, add all of the prev events as backwards extremities.
- # Ignore any events that are already backwards extrems or outliers.
- query = (
- "INSERT INTO event_backward_extremities (event_id, room_id)"
- " SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- " )"
- # 1. Don't add an event as a extremity again if we already persisted it
- # as a non-outlier.
- # 2. Don't add an outlier as an extremity if it has no prev_events
- " AND NOT EXISTS ("
- " SELECT 1 FROM events"
- " LEFT JOIN event_edges edge"
- " ON edge.event_id = events.event_id"
- " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = FALSE OR edge.event_id IS NULL)"
- " )"
+
+ room_id = events[0].room_id
+
+ potential_backwards_extremities = {
+ e_id
+ for ev in events
+ for e_id in ev.prev_event_ids()
+ if not ev.internal_metadata.is_outlier()
+ }
+
+ if not potential_backwards_extremities:
+ return
+
+ existing_events_outliers = self.db_pool.simple_select_many_txn(
+ txn,
+ table="events",
+ column="event_id",
+ iterable=potential_backwards_extremities,
+ keyvalues={"outlier": False},
+ retcols=("event_id",),
)
- txn.execute_batch(
- query,
- [
- (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id)
- for ev in events
- for e_id in ev.prev_event_ids()
- if not ev.internal_metadata.is_outlier()
- ],
+ potential_backwards_extremities.difference_update(
+ e for e, in existing_events_outliers
)
+ if potential_backwards_extremities:
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ table="event_backward_extremities",
+ key_names=("room_id", "event_id"),
+ key_values=[(room_id, ev) for ev in potential_backwards_extremities],
+ value_names=(),
+ value_values=(),
+ )
+
+ # Record the stream orderings where we have new gaps.
+ gap_events = [
+ (room_id, self._instance_name, ev.internal_metadata.stream_ordering)
+ for ev in events
+ if any(
+ e_id in potential_backwards_extremities
+ for e_id in ev.prev_event_ids()
+ )
+ ]
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="timeline_gaps",
+ keys=("room_id", "instance_name", "stream_ordering"),
+ values=gap_events,
+ )
+
# Delete all these events that we've already fetched and now know that their
# prev events are the new backwards extremeties.
query = (
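The rewritten backward-extremities hunk above replaces a per-row INSERT ... WHERE NOT EXISTS with set arithmetic: gather every prev_event referenced by the non-outlier events being persisted, discard those already stored as non-outliers, and upsert the remainder. The core set logic, with made-up event IDs:

from typing import Dict, List, Set

# (event_id, prev_events, outlier) stands in for EventBase here.
events: List[Dict] = [
    {"event_id": "$C", "prev_events": ["$A", "$B"], "outlier": False},
    {"event_id": "$D", "prev_events": ["$B"], "outlier": True},
]

potential_backwards_extremities: Set[str] = {
    prev
    for ev in events
    if not ev["outlier"]
    for prev in ev["prev_events"]
}

# Events we have already persisted as non-outliers are not gaps.
already_persisted_non_outliers = {"$A"}
potential_backwards_extremities.difference_update(already_persisted_non_outliers)

print(potential_backwards_extremities)  # {'$B'}: only $B still needs backfilling

The subsequent insert into timeline_gaps then records the stream orderings at which those gaps appeared.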
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index daef3685b0..0061805150 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -369,18 +369,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
- ev_rows = self.db_pool.simple_select_many_txn(
- txn,
- table="event_json",
- column="event_id",
- iterable=chunk,
- retcols=["event_id", "json"],
- keyvalues={},
+ ev_rows = cast(
+ List[Tuple[str, str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_json",
+ column="event_id",
+ iterable=chunk,
+ retcols=["event_id", "json"],
+ keyvalues={},
+ ),
)
- for row in ev_rows:
- event_id = row["event_id"]
- event_json = db_to_json(row["json"])
+ for event_id, json in ev_rows:
+ event_json = db_to_json(json)
try:
origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError):
@@ -563,15 +565,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if deleted:
# We now need to invalidate the caches of these rooms
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="events",
- column="event_id",
- iterable=to_delete,
- keyvalues={},
- retcols=("room_id",),
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="events",
+ column="event_id",
+ iterable=to_delete,
+ keyvalues={},
+ retcols=("room_id",),
+ ),
)
- room_ids = {row["room_id"] for row in rows}
+ room_ids = {row[0] for row in rows}
for room_id in room_ids:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
@@ -1038,18 +1043,21 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
count = len(rows)
# We also need to fetch the auth events for them.
- auth_events = self.db_pool.simple_select_many_txn(
- txn,
- table="event_auth",
- column="event_id",
- iterable=event_to_room_id,
- keyvalues={},
- retcols=("event_id", "auth_id"),
+ auth_events = cast(
+ List[Tuple[str, str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth",
+ column="event_id",
+ iterable=event_to_room_id,
+ keyvalues={},
+ retcols=("event_id", "auth_id"),
+ ),
)
event_to_auth_chain: Dict[str, List[str]] = {}
- for row in auth_events:
- event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+ for event_id, auth_id in auth_events:
+ event_to_auth_chain.setdefault(event_id, []).append(auth_id)
# Calculate and persist the chain cover index for this set of events.
#
@@ -1302,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
# indexes on it.
- # We need to pass execute a dummy function to handle the txn's result otherwise
- # it tries to call fetchall() on it and fails because there's no result to fetch.
- await self.db_pool.execute(
+ await self.db_pool.runInteraction(
"background_analyze_new_stream_ordering_column",
- lambda txn: None,
- "ANALYZE events(stream_ordering2)",
+ lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
)
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index f851bff604..0ba84b1469 100644
--- a/synapse/storage/databases/main/events_forward_extremities.py
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List
+from typing import List, Optional, Tuple, cast
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
@@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
async def get_forward_extremities_for_room(
self, room_id: str
- ) -> List[Dict[str, Any]]:
- """Get list of forward extremities for a room."""
+ ) -> List[Tuple[str, int, int, Optional[int]]]:
+ """
+ Get list of forward extremities for a room.
+
+ Returns:
+ A list of tuples of event_id, state_group, depth, and received_ts.
+ """
def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, int, int, Optional[int]]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
@@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
"""
txn.execute(sql, (room_id,))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_forward_extremities_for_room",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b788d70fc5..5bf864c1fb 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1584,16 +1584,19 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="events",
- retcols=("event_id",),
- column="event_id",
- iterable=list(event_ids),
- keyvalues={"outlier": False},
- desc="have_events_in_timeline",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
+ ),
)
- return {r["event_id"] for r in rows}
+ return {r[0] for r in rows}
@trace
@tag_args
@@ -2093,12 +2096,6 @@ class EventsWorkerStore(SQLBaseStore):
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
sql = """
- DELETE FROM event_txn_id
- WHERE inserted_ts < ?
- """
- txn.execute(sql, (one_day_ago,))
-
- sql = """
DELETE FROM event_txn_id_device_id
WHERE inserted_ts < ?
"""
@@ -2336,15 +2333,18 @@ class EventsWorkerStore(SQLBaseStore):
a dict mapping from event id to partial-stateness. We return True for
any of the events which are unknown (or are outliers).
"""
- result = await self.db_pool.simple_select_many_batch(
- table="partial_state_events",
- column="event_id",
- iterable=event_ids,
- retcols=["event_id"],
- desc="get_partial_state_events",
+ result = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="partial_state_events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=["event_id"],
+ desc="get_partial_state_events",
+ ),
)
# convert the result to a dict, to make @cachedList work
- partial = {r["event_id"] for r in result}
+ partial = {r[0] for r in result}
return {e_id: e_id in partial for e_id in event_ids}
@cached()
diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py
index 654f924019..60621edeef 100644
--- a/synapse/storage/databases/main/experimental_features.py
+++ b/synapse/storage/databases/main/experimental_features.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, FrozenSet
+from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
Returns:
the features currently enabled for the user
"""
- enabled = await self.db_pool.simple_select_list(
- "per_user_experimental_features",
- {"user_id": user_id, "enabled": True},
- ["feature"],
+ enabled = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="per_user_experimental_features",
+ keyvalues={"user_id": user_id, "enabled": True},
+ retcols=("feature",),
+ ),
)
- return frozenset(feature["feature"] for feature in enabled)
+ return frozenset(feature[0] for feature in enabled)
async def set_features_for_user(
self,
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 889c578b9c..ce88772f9e 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,7 +16,7 @@
import itertools
import json
import logging
-from typing import Dict, Iterable, Mapping, Optional, Tuple
+from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -205,35 +205,39 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="server_keys_json",
- column="key_id",
- iterable=key_ids,
- keyvalues={"server_name": server_name},
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
+ rows = cast(
+ List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
+ await self.db_pool.simple_select_many_batch(
+ table="server_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
),
- desc="get_server_keys_json_for_remote",
)
if not rows:
return {}
- # We sort the rows so that the most recently added entry is picked up.
- rows.sort(key=lambda r: r["ts_added_ms"])
+ # We sort the rows by ts_added_ms so that the most recently added entry
+ # will stomp over older entries in the dictionary.
+ rows.sort(key=lambda r: r[2])
return {
- row["key_id"]: FetchKeyResultForRemote(
+ key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
- key_json=bytes(row["key_json"]),
- valid_until_ts=row["ts_valid_until_ms"],
- added_ts=row["ts_added_ms"],
+ key_json=bytes(key_json),
+ valid_until_ts=ts_valid_until_ms,
+ added_ts=ts_added_ms,
)
- for row in rows
+ for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
}
async def get_all_server_keys_json_for_remote(
@@ -244,30 +248,35 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent.
"""
- rows = await self.db_pool.simple_select_list(
- table="server_keys_json",
- keyvalues={"server_name": server_name},
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
+ rows = cast(
+ List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
+ await self.db_pool.simple_select_list(
+ table="server_keys_json",
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
),
- desc="get_server_keys_json_for_remote",
)
if not rows:
return {}
- rows.sort(key=lambda r: r["ts_added_ms"])
+ # We sort the rows by ts_added_ms so that the most recently added entry
+ # will stomp over older entries in the dictionary.
+ rows.sort(key=lambda r: r[2])
return {
- row["key_id"]: FetchKeyResultForRemote(
+ key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
- key_json=bytes(row["key_json"]),
- valid_until_ts=row["ts_valid_until_ms"],
- added_ts=row["ts_added_ms"],
+ key_json=bytes(key_json),
+ valid_until_ts=ts_valid_until_ms,
+ added_ts=ts_added_ms,
)
- for row in rows
+ for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
}
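
The sort-then-overwrite trick above relies on a dict comprehension keeping the value written last for a duplicate key: sorting ascending by `ts_added_ms` means the most recently added row wins. A tiny standalone sketch with made-up key data:

from typing import List, Tuple

# (key_id, ts_added_ms, key_json)
rows: List[Tuple[str, int, bytes]] = [
    ("ed25519:a", 200, b"newer"),
    ("ed25519:a", 100, b"older"),
]
rows.sort(key=lambda r: r[1])  # oldest first, so the newest overwrites last
latest = {key_id: key_json for key_id, _ts_added_ms, key_json in rows}
assert latest["ed25519:a"] == b"newer"
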
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 8cebeb5189..c8d7c9fd32 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -26,8 +26,11 @@ from typing import (
cast,
)
+import attr
+
from synapse.api.constants import Direction
from synapse.logging.opentracing import trace
+from synapse.media._base import ThumbnailInfo
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -44,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LocalMedia:
+ media_id: str
+ media_type: str
+ media_length: int
+ upload_name: str
+ created_ts: int
+ last_access_ts: int
+ quarantined_by: Optional[str]
+ safe_from_quarantine: bool
+
+
class MediaSortOrder(Enum):
"""
Enum to define the sorting method used when returning media with
@@ -179,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
user_id: str,
order_by: str = MediaSortOrder.CREATED_TS.value,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[Dict[str, Any]], int]:
+ ) -> Tuple[List[LocalMedia], int]:
"""Get a paginated list of metadata for a local piece of media
which an user_id has uploaded
@@ -196,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[Dict[str, Any]], int]:
+ ) -> Tuple[List[LocalMedia], int]:
# Set ordering
order_by_column = MediaSortOrder(order_by).value
@@ -216,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = """
SELECT
- "media_id",
- "media_type",
- "media_length",
- "upload_name",
- "created_ts",
- "last_access_ts",
- "quarantined_by",
- "safe_from_quarantine"
+ media_id,
+ media_type,
+ media_length,
+ upload_name,
+ created_ts,
+ last_access_ts,
+ quarantined_by,
+ safe_from_quarantine
FROM local_media_repository
WHERE user_id = ?
ORDER BY {order_by_column} {order}, media_id ASC
@@ -235,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
args += [limit, start]
txn.execute(sql, args)
- media = self.db_pool.cursor_to_dict(txn)
+ media = [
+ LocalMedia(
+ media_id=row[0],
+ media_type=row[1],
+ media_length=row[2],
+ upload_name=row[3],
+ created_ts=row[4],
+ last_access_ts=row[5],
+ quarantined_by=row[6],
+ safe_from_quarantine=bool(row[7]),
+ )
+ for row in txn
+ ]
return media, count
return await self.db_pool.runInteraction(
@@ -435,19 +462,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_url_cache",
)
- async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
- "local_media_repository_thumbnails",
- {"media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
+ async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "local_media_repository_thumbnails",
+ {"media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_local_media_thumbnails",
),
- desc="get_local_media_thumbnails",
)
+ return [
+ ThumbnailInfo(
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
+ )
+ for row in rows
+ ]
@trace
async def store_local_thumbnail(
@@ -556,20 +592,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_thumbnails(
self, origin: str, media_id: str
- ) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
- "remote_media_cache_thumbnails",
- {"media_origin": origin, "media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
- "filesystem_id",
+ ) -> List[ThumbnailInfo]:
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "remote_media_cache_thumbnails",
+ {"media_origin": origin, "media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_remote_media_thumbnails",
),
- desc="get_remote_media_thumbnails",
)
+ return [
+ ThumbnailInfo(
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
+ )
+ for row in rows
+ ]
@trace
async def get_remote_media_thumbnail(
@@ -632,7 +676,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
- ) -> List[Dict[str, str]]:
+ ) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.
@@ -646,12 +690,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
+ * The filesystem ID.
+ """
+
+ sql = """
+ SELECT media_origin, media_id, filesystem_id
+ FROM remote_media_cache
+ WHERE last_access_ts < ?
"""
- sql = (
- "SELECT media_origin, media_id, filesystem_id"
- " FROM remote_media_cache"
- " WHERE last_access_ts < ?"
- )
if include_quarantined_media is False:
# Only include media that has not been quarantined
@@ -659,8 +705,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
AND quarantined_by IS NULL
"""
- return await self.db_pool.execute(
- "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
+ return cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
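
The media hunks above convert positional rows into frozen value objects (`LocalMedia`, `ThumbnailInfo`) instead of per-row dicts. A standalone sketch of the same conversion, using a stdlib dataclass as a stand-in for the attrs classes:

from dataclasses import dataclass
from typing import List, Tuple

@dataclass(frozen=True)
class ThumbnailRow:
    width: int
    height: int
    method: str
    type: str
    length: int

def to_thumbnails(rows: List[Tuple[int, int, str, str, int]]) -> List[ThumbnailRow]:
    # Column order must match the SELECT / retcols order exactly.
    return [
        ThumbnailRow(width=r[0], height=r[1], method=r[2], type=r[3], length=r[4])
        for r in rows
    ]

assert to_thumbnails([(32, 32, "crop", "image/png", 2048)])[0].method == "crop"
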
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 194b4e031f..3b444d2d07 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -20,6 +20,7 @@ from typing import (
Mapping,
Optional,
Tuple,
+ Union,
cast,
)
@@ -260,27 +261,40 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
async def get_presence_for_users(
self, user_ids: Iterable[str]
) -> Mapping[str, UserPresenceState]:
- rows = await self.db_pool.simple_select_many_batch(
- table="presence_stream",
- column="user_id",
- iterable=user_ids,
- keyvalues={},
- retcols=(
- "user_id",
- "state",
- "last_active_ts",
- "last_federation_update_ts",
- "last_user_sync_ts",
- "status_msg",
- "currently_active",
+ # TODO All these columns are nullable, but we don't expect that:
+ # https://github.com/matrix-org/synapse/issues/16467
+ rows = cast(
+ List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
+ await self.db_pool.simple_select_many_batch(
+ table="presence_stream",
+ column="user_id",
+ iterable=user_ids,
+ keyvalues={},
+ retcols=(
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
+ desc="get_presence_for_users",
),
- desc="get_presence_for_users",
)
- for row in rows:
- row["currently_active"] = bool(row["currently_active"])
-
- return {row["user_id"]: UserPresenceState(**row) for row in rows}
+ return {
+ user_id: UserPresenceState(
+ user_id=user_id,
+ state=state,
+ last_active_ts=last_active_ts,
+ last_federation_update_ts=last_federation_update_ts,
+ last_user_sync_ts=last_user_sync_ts,
+ status_msg=status_msg,
+ currently_active=bool(currently_active),
+ )
+ for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
+ }
async def should_user_receive_full_presence_with_token(
self,
@@ -385,28 +399,49 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
limit = 100
offset = 0
while True:
- rows = await self.db_pool.runInteraction(
- "get_presence_for_all_users",
- self.db_pool.simple_select_list_paginate_txn,
- "presence_stream",
- orderby="stream_id",
- start=offset,
- limit=limit,
- exclude_keyvalues=exclude_keyvalues,
- retcols=(
- "user_id",
- "state",
- "last_active_ts",
- "last_federation_update_ts",
- "last_user_sync_ts",
- "status_msg",
- "currently_active",
+ # TODO All these columns are nullable, but we don't expect that:
+ # https://github.com/matrix-org/synapse/issues/16467
+ rows = cast(
+ List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
+ await self.db_pool.runInteraction(
+ "get_presence_for_all_users",
+ self.db_pool.simple_select_list_paginate_txn,
+ "presence_stream",
+ orderby="stream_id",
+ start=offset,
+ limit=limit,
+ exclude_keyvalues=exclude_keyvalues,
+ retcols=(
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
+ order_direction="ASC",
),
- order_direction="ASC",
)
- for row in rows:
- users_to_state[row["user_id"]] = UserPresenceState(**row)
+ for (
+ user_id,
+ state,
+ last_active_ts,
+ last_federation_update_ts,
+ last_user_sync_ts,
+ status_msg,
+ currently_active,
+ ) in rows:
+ users_to_state[user_id] = UserPresenceState(
+ user_id=user_id,
+ state=state,
+ last_active_ts=last_active_ts,
+ last_federation_update_ts=last_federation_update_ts,
+ last_user_sync_ts=last_user_sync_ts,
+ status_msg=status_msg,
+ currently_active=bool(currently_active),
+ )
# We've run out of updates to query
if len(rows) < limit:
@@ -434,13 +469,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
txn.close()
- for row in rows:
- row["currently_active"] = bool(row["currently_active"])
-
- return [UserPresenceState(**row) for row in rows]
+ return [
+ UserPresenceState(
+ user_id=user_id,
+ state=state,
+ last_active_ts=last_active_ts,
+ last_federation_update_ts=last_federation_update_ts,
+ last_user_sync_ts=last_user_sync_ts,
+ status_msg=status_msg,
+ currently_active=bool(currently_active),
+ )
+ for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
+ ]
def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup
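
The presence conversions above unpack each row positionally and normalise `currently_active` with `bool()`, since SQLite hands booleans back as 0/1 integers. A self-contained sketch of that normalisation (a plain dict stands in for `UserPresenceState`):

from typing import Dict, List, Optional, Tuple, Union

Row = Tuple[str, str, int, Optional[str], Union[int, bool]]

def to_presence(rows: List[Row]) -> Dict[str, dict]:
    return {
        user_id: {
            "state": state,
            "last_active_ts": last_active_ts,
            "status_msg": status_msg,
            # SQLite returns 0/1 here; make it a real boolean for callers.
            "currently_active": bool(currently_active),
        }
        for user_id, state, last_active_ts, status_msg, currently_active in rows
    }

states = to_presence([("@alice:test", "online", 123, None, 1)])
assert states["@alice:test"]["currently_active"] is True
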
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index dea0e0458c..1e11bf2706 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -89,6 +89,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# furthermore, we might already have the table from a previous (failed)
# purge attempt, so let's drop the table first.
+ if isinstance(self.database_engine, PostgresEngine):
+ # Disable statement timeouts for this transaction; purging rooms can
+ # take a while!
+ txn.execute("SET LOCAL statement_timeout = 0")
+
txn.execute("DROP TABLE IF EXISTS events_to_purge")
txn.execute(
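
On PostgreSQL, `SET LOCAL` scopes a setting to the enclosing transaction, so the statement timeout is only lifted for the long-running purge and reverts automatically on commit or rollback. A sketch of the guard below uses an illustrative `is_postgres` flag and transaction object, not Synapse's real engine classes:

def disable_timeout_for_purge(txn, is_postgres: bool) -> None:
    if is_postgres:
        # Only affects this transaction; other connections keep their timeout.
        txn.execute("SET LOCAL statement_timeout = 0")
    txn.execute("DROP TABLE IF EXISTS events_to_purge")
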
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 923166974c..37135d431d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,8 +28,11 @@ from typing import (
cast,
)
+from twisted.internet import defer
+
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -51,7 +54,8 @@ from synapse.storage.util.id_generators import (
)
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import json_encoder, unwrapFirstError
+from synapse.util.async_helpers import gather_results
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -62,20 +66,34 @@ logger = logging.getLogger(__name__)
def _load_rules(
- rawrules: List[JsonDict],
+ rawrules: List[Tuple[str, int, str, str]],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
) -> FilteredPushRules:
"""Take the DB rows returned from the DB and convert them into a full
`FilteredPushRules` object.
+
+ Args:
+ rawrules: List of tuples of:
+ * rule ID
+ * Priority class
+ * Conditions (as serialized JSON)
+ * Actions (as serialized JSON)
+ enabled_map: A dictionary of rule ID to a boolean of whether the rule is
+ enabled. This might not include all rule IDs from rawrules.
+ experimental_config: The `experimental_features` section of the Synapse
+ config. (Used to check if various features are enabled.)
+
+ Returns:
+ A new FilteredPushRules object.
"""
ruleslist = [
PushRule.from_db(
- rule_id=rawrule["rule_id"],
- priority_class=rawrule["priority_class"],
- conditions=rawrule["conditions"],
- actions=rawrule["actions"],
+ rule_id=rawrule[0],
+ priority_class=rawrule[1],
+ conditions=rawrule[2],
+ actions=rawrule[3],
)
for rawrule in rawrules
]
@@ -165,34 +183,44 @@ class PushRulesWorkerStore(
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
- rows = await self.db_pool.simple_select_list(
- table="push_rules",
- keyvalues={"user_name": user_id},
- retcols=(
- "user_name",
- "rule_id",
- "priority_class",
- "priority",
- "conditions",
- "actions",
+ rows = cast(
+ List[Tuple[str, int, int, str, str]],
+ await self.db_pool.simple_select_list(
+ table="push_rules",
+ keyvalues={"user_name": user_id},
+ retcols=(
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
+ ),
+ desc="get_push_rules_for_user",
),
- desc="get_push_rules_for_user",
)
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+ # Sort by highest priority_class, then highest priority.
+ rows.sort(key=lambda row: (-int(row[1]), -int(row[2])))
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
- return _load_rules(rows, enabled_map, self.hs.config.experimental)
+ return _load_rules(
+ [(row[0], row[1], row[3], row[4]) for row in rows],
+ enabled_map,
+ self.hs.config.experimental,
+ )
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
- results = await self.db_pool.simple_select_list(
- table="push_rules_enable",
- keyvalues={"user_name": user_id},
- retcols=("rule_id", "enabled"),
- desc="get_push_rules_enabled_for_user",
+ results = cast(
+ List[Tuple[str, Optional[Union[int, bool]]]],
+ await self.db_pool.simple_select_list(
+ table="push_rules_enable",
+ keyvalues={"user_name": user_id},
+ retcols=("rule_id", "enabled"),
+ desc="get_push_rules_enabled_for_user",
+ ),
)
- return {r["rule_id"]: bool(r["enabled"]) for r in results}
+ return {r[0]: bool(r[1]) for r in results}
async def have_push_rules_changed_for_user(
self, user_id: str, last_id: int
@@ -221,23 +249,46 @@ class PushRulesWorkerStore(
if not user_ids:
return {}
- raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+ raw_rules: Dict[str, List[Tuple[str, int, str, str]]] = {
+ user_id: [] for user_id in user_ids
+ }
- rows = await self.db_pool.simple_select_many_batch(
- table="push_rules",
- column="user_name",
- iterable=user_ids,
- retcols=("*",),
- desc="bulk_get_push_rules",
- batch_size=1000,
+ # gatherResults loses all type information.
+ rows, enabled_map_by_user = await make_deferred_yieldable(
+ gather_results(
+ (
+ cast(
+ "defer.Deferred[List[Tuple[str, str, int, int, str, str]]]",
+ run_in_background(
+ self.db_pool.simple_select_many_batch,
+ table="push_rules",
+ column="user_name",
+ iterable=user_ids,
+ retcols=(
+ "user_name",
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
+ ),
+ desc="bulk_get_push_rules",
+ batch_size=1000,
+ ),
+ ),
+ run_in_background(self.bulk_get_push_rules_enabled, user_ids),
+ ),
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
)
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
-
- for row in rows:
- raw_rules.setdefault(row["user_name"], []).append(row)
+ # Sort by highest priority_class, then highest priority.
+ rows.sort(key=lambda row: (-int(row[2]), -int(row[3])))
- enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
+ for user_name, rule_id, priority_class, _, conditions, actions in rows:
+ raw_rules.setdefault(user_name, []).append(
+ (rule_id, priority_class, conditions, actions)
+ )
results: Dict[str, FilteredPushRules] = {}
@@ -256,17 +307,19 @@ class PushRulesWorkerStore(
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
- rows = await self.db_pool.simple_select_many_batch(
- table="push_rules_enable",
- column="user_name",
- iterable=user_ids,
- retcols=("user_name", "rule_id", "enabled"),
- desc="bulk_get_push_rules_enabled",
- batch_size=1000,
+ rows = cast(
+ List[Tuple[str, str, Optional[int]]],
+ await self.db_pool.simple_select_many_batch(
+ table="push_rules_enable",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("user_name", "rule_id", "enabled"),
+ desc="bulk_get_push_rules_enabled",
+ batch_size=1000,
+ ),
)
- for row in rows:
- enabled = bool(row["enabled"])
- results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
+ for user_name, rule_id, enabled in rows:
+ results.setdefault(user_name, {})[rule_id] = bool(enabled)
return results
async def get_all_push_rule_updates(
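
Push rules above are ordered by highest `priority_class`, then highest `priority`, which is done by negating both integer sort keys. A standalone sketch with made-up rule IDs:

from typing import List, Tuple

# (rule_id, priority_class, priority)
rules: List[Tuple[str, int, int]] = [
    ("global/underride/.m.rule.message", 1, 1),
    ("global/override/.m.rule.master", 5, 0),
    ("global/override/custom", 5, 3),
]
rules.sort(key=lambda row: (-row[1], -row[2]))
assert [r[0] for r in rules] == [
    "global/override/custom",
    "global/override/.m.rule.master",
    "global/underride/.m.rule.message",
]
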
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 87e28e22d3..a6a1671bd6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -47,6 +47,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# The type of a row in the pushers table.
+PusherRow = Tuple[
+ int, # id
+ str, # user_name
+ Optional[int], # access_token
+ str, # profile_tag
+ str, # kind
+ str, # app_id
+ str, # app_display_name
+ str, # device_display_name
+ str, # pushkey
+ int, # ts
+ str, # lang
+ str, # data
+ int, # last_stream_ordering
+ int, # last_success
+ int, # failing_since
+ bool, # enabled
+ str, # device_id
+]
+
class PusherWorkerStore(SQLBaseStore):
def __init__(
@@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_deleted_email_pushers,
)
- def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
+ def _decode_pushers_rows(
+ self,
+ rows: Iterable[PusherRow],
+ ) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
"""
- for r in rows:
- data_json = r["data"]
+ for (
+ id,
+ user_name,
+ access_token,
+ profile_tag,
+ kind,
+ app_id,
+ app_display_name,
+ device_display_name,
+ pushkey,
+ ts,
+ lang,
+ data,
+ last_stream_ordering,
+ last_success,
+ failing_since,
+ enabled,
+ device_id,
+ ) in rows:
try:
- r["data"] = db_to_json(data_json)
+ data_json = db_to_json(data)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
- r["id"],
- data_json,
+ id,
+ data,
e.args[0],
)
continue
- # If we're using SQLite, then boolean values are integers. This is
- # troublesome since some code using the return value of this method might
- # expect it to be a boolean, or will expose it to clients (in responses).
- r["enabled"] = bool(r["enabled"])
-
- yield PusherConfig(**r)
+ yield PusherConfig(
+ id=id,
+ user_name=user_name,
+ profile_tag=profile_tag,
+ kind=kind,
+ app_id=app_id,
+ app_display_name=app_display_name,
+ device_display_name=device_display_name,
+ pushkey=pushkey,
+ ts=ts,
+ lang=lang,
+ data=data_json,
+ last_stream_ordering=last_stream_ordering,
+ last_success=last_success,
+ failing_since=failing_since,
+ # If we're using SQLite, then boolean values are integers. This is
+ # troublesome since some code using the return value of this method might
+ # expect it to be a boolean, or will expose it to clients (in responses).
+ enabled=bool(enabled),
+ device_id=device_id,
+ access_token=access_token,
+ )
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
@@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
The pushers for which the given columns have the given values.
"""
- def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]:
# We could technically use simple_select_list here, but we need to call
# COALESCE on the 'enabled' column. While it is technically possible to give
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
@@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore):
txn.execute(sql, list(keyvalues.values()))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[PusherRow], txn.fetchall())
ret = await self.db_pool.runInteraction(
desc="get_pushers_by",
@@ -164,14 +221,22 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret)
async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
- def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
- txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
- rows = self.db_pool.cursor_to_dict(txn)
-
- return self._decode_pushers_rows(rows)
+ def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]:
+ txn.execute(
+ """
+ SELECT id, user_name, access_token, profile_tag, kind, app_id,
+ app_display_name, device_display_name, pushkey, ts, lang, data,
+ last_stream_ordering, last_success, failing_since,
+ enabled, device_id
+ FROM pushers WHERE COALESCE(enabled, TRUE)
+ """
+ )
+ return cast(List[PusherRow], txn.fetchall())
- return await self.db_pool.runInteraction(
- "get_enabled_pushers", get_enabled_pushers_txn
+ return self._decode_pushers_rows(
+ await self.db_pool.runInteraction(
+ "get_enabled_pushers", get_enabled_pushers_txn
+ )
)
async def get_all_updated_pushers_rows(
@@ -304,26 +369,28 @@ class PusherWorkerStore(SQLBaseStore):
)
async def get_throttle_params_by_room(
- self, pusher_id: str
+ self, pusher_id: int
) -> Dict[str, ThrottleParams]:
- res = await self.db_pool.simple_select_list(
- "pusher_throttle",
- {"pusher": pusher_id},
- ["room_id", "last_sent_ts", "throttle_ms"],
- desc="get_throttle_params_by_room",
+ res = cast(
+ List[Tuple[str, Optional[int], Optional[int]]],
+ await self.db_pool.simple_select_list(
+ "pusher_throttle",
+ {"pusher": pusher_id},
+ ["room_id", "last_sent_ts", "throttle_ms"],
+ desc="get_throttle_params_by_room",
+ ),
)
params_by_room = {}
- for row in res:
- params_by_room[row["room_id"]] = ThrottleParams(
- row["last_sent_ts"],
- row["throttle_ms"],
+ for room_id, last_sent_ts, throttle_ms in res:
+ params_by_room[room_id] = ThrottleParams(
+ last_sent_ts or 0, throttle_ms or 0
)
return params_by_room
async def set_throttle_params(
- self, pusher_id: str, room_id: str, params: ThrottleParams
+ self, pusher_id: int, room_id: str, params: ThrottleParams
) -> None:
await self.db_pool.simple_upsert(
"pusher_throttle",
@@ -534,7 +601,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
(last_pusher_id, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if len(rows) == 0:
return 0
@@ -550,19 +617,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="pushers",
key_names=("id",),
- key_values=[(row["pusher_id"],) for row in rows],
+ key_values=[(row[0],) for row in rows],
value_names=("device_id", "access_token"),
# If there was already a device_id on the pusher, we only want to clear
# the access_token column, so we keep the existing device_id. Otherwise,
# we set the device_id we got from joining the access_tokens table.
value_values=[
- (row["pusher_device_id"] or row["token_device_id"], None)
- for row in rows
+ (pusher_device_id or token_device_id, None)
+ for _, pusher_device_id, token_device_id in rows
],
)
self.db_pool.updates._background_update_progress_txn(
- txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]}
+ txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]}
)
return len(rows)
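
`_decode_pushers_rows` above logs and skips any row whose `data` column is not valid JSON rather than failing the whole result. A self-contained sketch of that drop-bad-rows behaviour, with a reduced two-column row shape:

import json
import logging
from typing import Iterable, Iterator, Tuple

logger = logging.getLogger(__name__)

def decode_rows(rows: Iterable[Tuple[int, str]]) -> Iterator[Tuple[int, dict]]:
    for pusher_id, data in rows:
        try:
            yield pusher_id, json.loads(data)
        except Exception as e:
            # Bad JSON: warn and move on instead of raising.
            logger.warning("Invalid JSON for pusher %d: %s", pusher_id, e)
            continue

assert list(decode_rows([(1, '{"url": "https://push"}'), (2, "not json")])) == [
    (1, {"url": "https://push"})
]
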
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0231f9407b..56e8eb16a8 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -28,6 +28,8 @@ from typing import (
cast,
)
+from immutabledict import immutabledict
+
from synapse.api.constants import EduTypes
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict, JsonMapping
+from synapse.types import (
+ JsonDict,
+ JsonMapping,
+ MultiWriterStreamToken,
+ PersistedPosition,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"receipts_linearized",
entity_column="room_id",
stream_column="stream_id",
- max_value=max_receipts_stream_id,
+ max_value=max_receipts_stream_id.stream,
limit=10000,
)
self._receipts_stream_cache = StreamChangeCache(
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
prefilled_cache=receipts_stream_prefill,
)
- def get_max_receipt_stream_id(self) -> int:
+ def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
"""Get the current max stream ID for receipts stream"""
- return self._receipts_id_gen.get_current_token()
+
+ min_pos = self._receipts_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+ # positions are ahead of that. (Instance positions can be behind the
+ # min position as there are times we can work out that the minimum
+ # position is ahead of the naive minimum across all current
+ # positions. See MultiWriterIdGenerator for details)
+ positions = {
+ i: p
+ for i, p in self._receipts_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
+ return MultiWriterStreamToken(
+ stream=min_pos, instance_map=immutabledict(positions)
+ )
+
+ def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
+ return self._receipts_id_gen.get_current_token_for_writer(instance_name)
def get_last_unthreaded_receipt_for_user_txn(
self,
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
async def get_linearized_receipts_for_rooms(
- self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Iterable[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> List[JsonMapping]:
"""Get receipts for multiple rooms for sending to clients.
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
room_ids = self._receipts_stream_cache.get_entities_changed(
- room_ids, from_key
+ room_ids, from_key.stream
)
results = await self._get_linearized_receipts_for_rooms(
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [ev for res in results.values() for ev in res]
async def get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""Get receipts for a single room for sending to clients.
@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore):
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
- if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+ if not self._receipts_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ ):
return []
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
@cached(tree=True)
async def _get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
if from_key:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id > ? AND stream_id <= ?"
- )
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized
+ WHERE room_id = ? AND stream_id > ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, from_key, to_key))
- else:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id <= ?"
+ txn.execute(
+ sql, (room_id, from_key.stream, to_key.get_max_stream_pos())
)
+ else:
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
+ room_id = ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, to_key))
-
- rows = self.db_pool.cursor_to_dict(txn)
+ txn.execute(sql, (room_id, to_key.get_max_stream_pos()))
- return rows
+ return [
+ (receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
@@ -339,10 +387,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return []
content: JsonDict = {}
- for row in rows:
- content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
- row["user_id"]
- ] = db_to_json(row["data"])
+ for receipt_type, user_id, event_id, data in rows:
+ content.setdefault(event_id, {}).setdefault(receipt_type, {})[
+ user_id
+ ] = db_to_json(data)
return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
@@ -352,25 +400,37 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=3,
)
async def _get_linearized_receipts_for_rooms(
- self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Collection[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, Sequence[JsonMapping]]:
if not room_ids:
return {}
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
if from_key:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
+ FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [from_key, to_key] + list(args))
+ txn.execute(
+ sql + clause,
+ [from_key.stream, to_key.get_max_stream_pos()] + list(args),
+ )
else:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
+ FROM receipts_linearized WHERE
stream_id <= ? AND
"""
@@ -378,31 +438,37 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [to_key] + list(args))
+ txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
- return self.db_pool.cursor_to_dict(txn)
+ return [
+ (room_id, receipt_type, user_id, event_id, thread_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
results: JsonDict = {}
- for row in txn_results:
+ for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
- row["room_id"],
- {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+ room_id,
+ {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(row["event_id"], {})
- receipt_type = event_entry.setdefault(row["receipt_type"], {})
+ event_entry = room_event["content"].setdefault(event_id, {})
+ receipt_type_dict = event_entry.setdefault(receipt_type, {})
- receipt_type[row["user_id"]] = db_to_json(row["data"])
- if row["thread_id"]:
- receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
+ receipt_type_dict[user_id] = db_to_json(data)
+ if thread_id:
+ receipt_type_dict[user_id]["thread_id"] = thread_id
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -414,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=2,
)
async def get_linearized_receipts_for_all_rooms(
- self, to_key: int, from_key: Optional[int] = None
+ self,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, JsonMapping]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
@@ -428,46 +496,54 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts.
"""
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
if from_key:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [from_key, to_key])
+ txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
else:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [to_key])
+ txn.execute(sql, [to_key.get_max_stream_pos()])
- return self.db_pool.cursor_to_dict(txn)
+ return [
+ (room_id, receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
)
results: JsonDict = {}
- for row in txn_results:
+ for room_id, receipt_type, user_id, event_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
- row["room_id"],
- {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+ room_id,
+ {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(row["event_id"], {})
- receipt_type = event_entry.setdefault(row["receipt_type"], {})
+ event_entry = room_event["content"].setdefault(event_id, {})
+ receipt_type_dict = event_entry.setdefault(receipt_type, {})
- receipt_type[row["user_id"]] = db_to_json(row["data"])
+ receipt_type_dict[user_id] = db_to_json(data)
return results
@@ -537,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
+ AND instance_name = ?
ORDER BY stream_id ASC
LIMIT ?
"""
- txn.execute(sql, (last_id, current_id, limit))
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
updates = cast(
List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
@@ -687,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
keyvalues=keyvalues,
values={
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"event_id": event_id,
"event_stream_ordering": stream_ordering,
"data": json_encoder.encode(data),
@@ -742,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_ids: List[str],
thread_id: Optional[str],
data: dict,
- ) -> Optional[Tuple[int, int]]:
+ ) -> Optional[PersistedPosition]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
@@ -804,9 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
- max_persisted_id = self._receipts_id_gen.get_current_token()
-
- return stream_id, max_persisted_id
+ return PersistedPosition(self._instance_name, stream_id)
async def _insert_graph_receipt(
self,
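
The receipts changes filter rows against a multi-writer token: conceptually a floor stream position plus per-writer positions, where a row is in range if it lies above the from-token and at or below the to-token for the instance that wrote it. A heavily simplified sketch of that idea under assumed semantics; this is not the real `MultiWriterStreamToken`:

from dataclasses import dataclass, field
from typing import Dict, Optional

@dataclass(frozen=True)
class Token:
    stream: int
    instance_map: Dict[str, int] = field(default_factory=dict)

    def for_instance(self, instance: str) -> int:
        # Writers not in the map are assumed to be at the floor position.
        return self.instance_map.get(instance, self.stream)

def in_range(frm: Optional[Token], to: Token, instance: str, stream_id: int) -> bool:
    low = frm.for_instance(instance) if frm is not None else 0
    return low < stream_id <= to.for_instance(instance)

to_key = Token(stream=10, instance_map={"worker2": 12})
assert in_range(None, to_key, "worker1", 10) is True
assert in_range(None, to_key, "worker1", 11) is False  # worker1 only known up to 10
assert in_range(None, to_key, "worker2", 12) is True   # worker2 has advanced to 12
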
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cc964604e2..933d76e905 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -143,6 +143,30 @@ class LoginTokenLookupResult:
"""The session ID advertised by the SSO Identity Provider."""
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidResult:
+ medium: str
+ address: str
+ validated_at: int
+ added_at: int
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidValidationSession:
+ address: str
+ """address of the 3pid"""
+ medium: str
+ """medium of the 3pid"""
+ client_secret: str
+ """a secret provided by the client for this validation session"""
+ session_id: str
+ """ID of the validation session"""
+ last_send_attempt: int
+ """a number serving to dedupe send attempts for this session"""
+ validated_at: Optional[int]
+ """timestamp of when this session was validated if so"""
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -195,7 +219,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""
- def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
# We could technically use simple_select_one here, but it would not perform
# the COALESCEs (unless hacked into the column names), which could yield
# confusing results.
@@ -213,35 +237,46 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
-
- if len(rows) == 0:
+ row = txn.fetchone()
+ if not row:
return None
- return rows[0]
+ (
+ name,
+ is_guest,
+ admin,
+ consent_version,
+ consent_ts,
+ consent_server_notice_sent,
+ appservice_id,
+ creation_ts,
+ user_type,
+ deactivated,
+ shadow_banned,
+ approved,
+ locked,
+ ) = row
+
+ return UserInfo(
+ appservice_id=appservice_id,
+ consent_server_notice_sent=consent_server_notice_sent,
+ consent_version=consent_version,
+ consent_ts=consent_ts,
+ creation_ts=creation_ts,
+ is_admin=bool(admin),
+ is_deactivated=bool(deactivated),
+ is_guest=bool(is_guest),
+ is_shadow_banned=bool(shadow_banned),
+ user_id=UserID.from_string(name),
+ user_type=user_type,
+ approved=bool(approved),
+ locked=bool(locked),
+ )
- row = await self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
desc="get_user_by_id",
func=get_user_by_id_txn,
)
- if row is None:
- return None
-
- return UserInfo(
- appservice_id=row["appservice_id"],
- consent_server_notice_sent=row["consent_server_notice_sent"],
- consent_version=row["consent_version"],
- consent_ts=row["consent_ts"],
- creation_ts=row["creation_ts"],
- is_admin=bool(row["admin"]),
- is_deactivated=bool(row["deactivated"]),
- is_guest=bool(row["is_guest"]),
- is_shadow_banned=bool(row["shadow_banned"]),
- user_id=UserID.from_string(row["name"]),
- user_type=row["user_type"],
- approved=bool(row["approved"]),
- locked=bool(row["locked"]),
- )
async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
@@ -579,16 +614,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""
txn.execute(sql, (token,))
- rows = self.db_pool.cursor_to_dict(txn)
-
- if rows:
- row = rows[0]
-
- # This field is nullable, ensure it comes out as a boolean
- if row["token_used"] is None:
- row["token_used"] = False
+ row = txn.fetchone()
- return TokenLookupResult(**row)
+ if row:
+ (
+ user_id,
+ is_guest,
+ shadow_banned,
+ token_id,
+ device_id,
+ valid_until_ms,
+ token_owner,
+ token_used,
+ ) = row
+
+ return TokenLookupResult(
+ user_id=user_id,
+ is_guest=is_guest,
+ shadow_banned=shadow_banned,
+ token_id=token_id,
+ device_id=device_id,
+ valid_until_ms=valid_until_ms,
+ token_owner=token_owner,
+ # This field is nullable, ensure it comes out as a boolean
+ token_used=bool(token_used),
+ )
return None
@@ -821,23 +871,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
Tuples of (auth_provider, external_id)
"""
- res = await self.db_pool.simple_select_list(
- table="user_external_ids",
- keyvalues={"user_id": mxid},
- retcols=("auth_provider", "external_id"),
- desc="get_external_ids_by_user",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="user_external_ids",
+ keyvalues={"user_id": mxid},
+ retcols=("auth_provider", "external_id"),
+ desc="get_external_ids_by_user",
+ ),
)
- return [(r["auth_provider"], r["external_id"]) for r in res]
async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
- txn.execute("SELECT COUNT(*) AS users FROM users")
- rows = self.db_pool.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
+ txn.execute("SELECT COUNT(*) FROM users")
+ row = txn.fetchone()
+ assert row is not None
+ return row[0]
return await self.db_pool.runInteraction("count_users", _count_users)
@@ -891,11 +942,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
- txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
- rows = self.db_pool.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
+ txn.execute("SELECT COUNT(*) FROM users where user_type is null")
+ row = txn.fetchone()
+ assert row is not None
+ return row[0]
return await self.db_pool.runInteraction("count_real_users", _count_users)
@@ -964,13 +1014,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
- "user_threepids",
- {"user_id": user_id},
- ["medium", "address", "validated_at", "added_at"],
- "user_get_threepids",
+ async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
+ results = cast(
+ List[Tuple[str, str, int, int]],
+ await self.db_pool.simple_select_list(
+ "user_threepids",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address", "validated_at", "added_at"],
+ desc="user_get_threepids",
+ ),
)
+ return [
+ ThreepidResult(
+ medium=r[0],
+ address=r[1],
+ validated_at=r[2],
+ added_at=r[3],
+ )
+ for r in results
+ ]
async def user_delete_threepid(
self, user_id: str, medium: str, address: str
@@ -1009,7 +1071,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="add_user_bound_threepid",
)
- async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
+ async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
"""Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids.
@@ -1018,15 +1080,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_id: The ID of the user to retrieve threepids for
Returns:
- List of dictionaries containing the following keys:
- medium (str): The medium of the threepid (e.g "email")
- address (str): The address of the threepid (e.g "bob@example.com")
- """
- return await self.db_pool.simple_select_list(
- table="user_threepid_id_server",
- keyvalues={"user_id": user_id},
- retcols=["medium", "address"],
- desc="user_get_bound_threepids",
+ List of tuples of two strings:
+ medium: The medium of the threepid (e.g "email")
+ address: The address of the threepid (e.g "bob@example.com")
+ """
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="user_threepid_id_server",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address"],
+ desc="user_get_bound_threepids",
+ ),
)
async def remove_user_bound_threepid(
@@ -1123,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThreepidValidationSession]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
@@ -1138,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
perform no filtering
Returns:
- A dict containing the following:
- * address - address of the 3pid
- * medium - medium of the 3pid
- * client_secret - a secret provided by the client for this validation session
- * session_id - ID of the validation session
- * send_attempt - a number serving to dedupe send attempts for this session
- * validated_at - timestamp of when this session was validated if so
-
- Otherwise None if a validation session is not found
+ A ThreepidValidationSession or None if a validation session is not found
"""
if not client_secret:
raise SynapseError(
@@ -1165,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def get_threepid_validation_session_txn(
txn: LoggingTransaction,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThreepidValidationSession]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
@@ -1180,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values()))
- rows = self.db_pool.cursor_to_dict(txn)
- if not rows:
+ row = txn.fetchone()
+ if not row:
return None
- return rows[0]
+ return ThreepidValidationSession(
+ address=row[0],
+ session_id=row[1],
+ medium=row[2],
+ client_secret=row[3],
+ last_send_attempt=row[4],
+ validated_at=row[5],
+ )
return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
@@ -1252,12 +1316,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
txn.execute(sql, [])
- res = self.db_pool.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
+ for (name,) in txn.fetchall():
+ self.set_expiration_date_for_user_txn(txn, name, use_delta=True)
await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
@@ -1457,7 +1517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_registration_tokens(
self, valid: Optional[bool] = None
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
"""List all registration tokens. Used by the admin API.
Args:
@@ -1466,34 +1526,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Default is None: return all tokens regardless of validity.
Returns:
- A list of dicts, each containing details of a token.
+ A list of tuples containing:
+ * The token
+ * The number of uses allowed (or None for unlimited)
+ * The number of pending uses the token has
+ * The number of completed uses the token has
+ * An expiry time (or None if no expiry)
"""
def select_registration_tokens_txn(
txn: LoggingTransaction, now: int, valid: Optional[bool]
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
if valid is None:
# Return all tokens regardless of validity
- txn.execute("SELECT * FROM registration_tokens")
+ txn.execute(
+ """
+ SELECT token, uses_allowed, pending, completed, expiry_time
+ FROM registration_tokens
+ """
+ )
elif valid:
# Select valid tokens only
- sql = (
- "SELECT * FROM registration_tokens WHERE "
- "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
- "AND (expiry_time > ? OR expiry_time IS NULL)"
- )
+ sql = """
+ SELECT token, uses_allowed, pending, completed, expiry_time
+ FROM registration_tokens
+ WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL)
+ AND (expiry_time > ? OR expiry_time IS NULL)
+ """
txn.execute(sql, [now])
else:
# Select invalid tokens only
- sql = (
- "SELECT * FROM registration_tokens WHERE "
- "uses_allowed <= pending + completed OR expiry_time <= ?"
- )
+ sql = """
+ SELECT token, uses_allowed, pending, completed, expiry_time
+ FROM registration_tokens
+ WHERE uses_allowed <= pending + completed OR expiry_time <= ?
+ """
txn.execute(sql, [now])
- return self.db_pool.cursor_to_dict(txn)
+ return cast(
+ List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"select_registration_tokens",
@@ -1963,11 +2037,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ row = txn.fetchone()
+ assert row is not None
# We cast to bool because the value returned by the database engine might
# be an integer if we're using SQLite.
- return bool(rows[0]["approved"])
+ return bool(row[0])
return await self.db_pool.runInteraction(
desc="is_user_pending_approval",
@@ -2045,22 +2120,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
(last_user, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return True, 0
rows_processed_nb = 0
- for user in rows:
- if not user["count_tokens"] and not user["count_threepids"]:
- self.set_user_deactivated_status_txn(txn, user["name"], True)
+ for name, count_tokens, count_threepids in rows:
+ if not count_tokens and not count_threepids:
+ self.set_user_deactivated_status_txn(txn, name, True)
rows_processed_nb += 1
logger.info("Marked %d rows as deactivated", rows_processed_nb)
self.db_pool.updates._background_update_progress_txn(
- txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+ txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
)
if batch_size > len(rows):
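
The `users_set_deactivated_flag` background update above follows the usual batching pattern: process up to `batch_size` rows, record the last key as progress, and signal completion once a batch comes back short. A standalone sketch of the per-batch step (the deactivation call is only indicated in a comment):

from typing import List, Tuple

def process_batch(rows: List[Tuple[str, int, int]], batch_size: int) -> Tuple[int, bool]:
    deactivated = 0
    for name, count_tokens, count_threepids in rows:
        if not count_tokens and not count_threepids:
            # Here the real code calls set_user_deactivated_status_txn(txn, name, True).
            deactivated += 1
    # A short batch means there is nothing left to process.
    finished = batch_size > len(rows)
    return deactivated, finished

assert process_batch([("@a:test", 0, 0), ("@b:test", 2, 1)], batch_size=100) == (1, True)
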
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c04d45bdb5..d0bc78b2e3 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -48,6 +48,6 @@ from synapse.storage.databases.main.stream import (
)
from synapse.storage.engines import PostgresEngine, Psycopg2Engine
-from synapse.types import JsonDict, StreamKeyType, StreamToken
+from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -314,7 +316,7 @@ class RelationsWorkerStore(SQLBaseStore):
room_key=next_key,
presence_key=0,
typing_key=0,
- receipt_key=0,
+ receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
@@ -349,16 +351,19 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_with_types_txn(
txn: LoggingTransaction,
) -> List[str]:
- rows = self.db_pool.simple_select_many_txn(
- txn=txn,
- table="event_relations",
- column="relation_type",
- iterable=relation_types,
- keyvalues={"relates_to_id": event_id},
- retcols=["event_id"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn=txn,
+ table="event_relations",
+ column="relation_type",
+ iterable=relation_types,
+ keyvalues={"relates_to_id": event_id},
+ retcols=["event_id"],
+ ),
)
- return [row["event_id"] for row in rows]
+ return [row[0] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event_with_types",
@@ -381,14 +386,17 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_txn(
txn: LoggingTransaction,
) -> List[str]:
- rows = self.db_pool.simple_select_list_txn(
- txn=txn,
- table="event_relations",
- keyvalues={"relates_to_id": event_id},
- retcols=["event_id"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_list_txn(
+ txn=txn,
+ table="event_relations",
+ keyvalues={"relates_to_id": event_id},
+ retcols=["event_id"],
+ ),
)
- return [row["event_id"] for row in rows]
+ return [row[0] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event",
@@ -458,7 +466,7 @@ class RelationsWorkerStore(SQLBaseStore):
)
return result is not None
- @cached()
+ @cached() # type: ignore[synapse-@cached-mutable]
async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
raise NotImplementedError()
@@ -512,11 +520,12 @@ class RelationsWorkerStore(SQLBaseStore):
"_get_references_for_events_txn", _get_references_for_events_txn
)
- @cached()
+ @cached() # type: ignore[synapse-@cached-mutable]
def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
raise NotImplementedError()
- @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
+ # TODO: This returns a mutable object, which is generally bad.
+ @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") # type: ignore[synapse-@cached-mutable]
async def get_applicable_edits(
self, event_ids: Collection[str]
) -> Mapping[str, Optional[EventBase]]:
@@ -598,11 +607,12 @@ class RelationsWorkerStore(SQLBaseStore):
for original_event_id in event_ids
}
- @cached()
+ @cached() # type: ignore[synapse-@cached-mutable]
def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]:
raise NotImplementedError()
- @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
+ # TODO: This returns a mutable object, which is generally bad.
+ @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") # type: ignore[synapse-@cached-mutable]
async def get_thread_summaries(
self, event_ids: Collection[str]
) -> Mapping[str, Optional[Tuple[int, EventBase]]]:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 719e11aea6..afb880532e 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -78,6 +78,31 @@ class RatelimitOverride:
burst_count: int
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LargestRoomStats:
+ room_id: str
+ name: Optional[str]
+ canonical_alias: Optional[str]
+ joined_members: int
+ join_rules: Optional[str]
+ guest_access: Optional[str]
+ history_visibility: Optional[str]
+ state_events: int
+ avatar: Optional[str]
+ topic: Optional[str]
+ room_type: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomStats(LargestRoomStats):
+ joined_local_members: int
+ version: Optional[str]
+ creator: Optional[str]
+ encryption: Optional[str]
+ federatable: bool
+ public: bool
+
+
class RoomSortOrder(Enum):
"""
Enum to define the sorting method used when returning rooms with get_rooms_paginate
@@ -204,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
allow_none=True,
)
- async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
+ async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
"""Retrieve room with statistics.
Args:
@@ -215,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_room_with_stats_txn(
txn: LoggingTransaction, room_id: str
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[RoomStats]:
sql = """
SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -229,15 +254,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
WHERE room_id = ?
"""
txn.execute(sql, [room_id])
- # Catch error if sql returns empty result to return "None" instead of an error
- try:
- res = self.db_pool.cursor_to_dict(txn)[0]
- except IndexError:
+ row = txn.fetchone()
+ if not row:
return None
-
- res["federatable"] = bool(res["federatable"])
- res["public"] = bool(res["public"])
- return res
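+ # The positional indices below follow the column order of the SELECT above.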
+ return RoomStats(
+ room_id=row[0],
+ name=row[1],
+ canonical_alias=row[2],
+ joined_members=row[3],
+ joined_local_members=row[4],
+ version=row[5],
+ creator=row[6],
+ encryption=row[7],
+ federatable=bool(row[8]),
+ public=bool(row[9]),
+ join_rules=row[10],
+ guest_access=row[11],
+ history_visibility=row[12],
+ state_events=row[13],
+ avatar=row[14],
+ topic=row[15],
+ room_type=row[16],
+ )
return await self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id
@@ -368,7 +406,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
bounds: Optional[Tuple[int, str]],
forwards: bool,
ignore_non_federatable: bool = False,
- ) -> List[Dict[str, Any]]:
+ ) -> List[LargestRoomStats]:
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
@@ -505,20 +543,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _get_largest_public_rooms_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Any]]:
+ ) -> List[LargestRoomStats]:
txn.execute(sql, query_args)
- results = self.db_pool.cursor_to_dict(txn)
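+ # Build LargestRoomStats from the positional columns of the query;
+ # state_events is not part of the query results, so it defaults to 0.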
+ results = [
+ LargestRoomStats(
+ room_id=r[0],
+ name=r[1],
+ canonical_alias=r[3],
+ joined_members=r[4],
+ join_rules=r[8],
+ guest_access=r[7],
+ history_visibility=r[6],
+ state_events=0,
+ avatar=r[5],
+ topic=r[2],
+ room_type=r[9],
+ )
+ for r in txn
+ ]
if not forwards:
results.reverse()
return results
- ret_val = await self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
- return ret_val
@cached(max_entries=10000)
async def is_room_blocked(self, room_id: str) -> Optional[bool]:
@@ -831,7 +883,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_retention_policy_for_room_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Optional[int]]]:
+ ) -> Optional[Tuple[Optional[int], Optional[int]]]:
txn.execute(
"""
SELECT min_lifetime, max_lifetime FROM room_retention
@@ -841,7 +893,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
(room_id,),
)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(Optional[Tuple[Optional[int], Optional[int]]], txn.fetchone())
ret = await self.db_pool.runInteraction(
"get_retention_policy_for_room",
@@ -856,8 +908,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
max_lifetime=self.config.retention.retention_default_max_lifetime,
)
- min_lifetime = ret[0]["min_lifetime"]
- max_lifetime = ret[0]["max_lifetime"]
+ min_lifetime, max_lifetime = ret
# If one of the room's policy's attributes isn't defined, use the matching
# attribute from the default policy.
@@ -1162,14 +1213,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, args)
- rows = self.db_pool.cursor_to_dict(txn)
- rooms_dict = {}
-
- for row in rows:
- rooms_dict[row["room_id"]] = RetentionPolicy(
- min_lifetime=row["min_lifetime"],
- max_lifetime=row["max_lifetime"],
+ rooms_dict = {
+ room_id: RetentionPolicy(
+ min_lifetime=min_lifetime,
+ max_lifetime=max_lifetime,
)
+ for room_id, min_lifetime, max_lifetime in txn
+ }
if include_null:
# If required, do a second query that retrieves all of the rooms we know
@@ -1178,13 +1228,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql)
- rows = self.db_pool.cursor_to_dict(txn)
-
# If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy.
- for row in rows:
- if row["room_id"] not in rooms_dict:
- rooms_dict[row["room_id"]] = RetentionPolicy()
+ for (room_id,) in txn:
+ if room_id not in rooms_dict:
+ rooms_dict[room_id] = RetentionPolicy()
return rooms_dict
@@ -1236,28 +1284,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"""
room_servers: Dict[str, PartialStateResyncInfo] = {}
- rows = await self.db_pool.simple_select_list(
- table="partial_state_rooms",
- keyvalues={},
- retcols=("room_id", "joined_via"),
- desc="get_server_which_served_partial_join",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="partial_state_rooms",
+ keyvalues={},
+ retcols=("room_id", "joined_via"),
+ desc="get_server_which_served_partial_join",
+ ),
)
- for row in rows:
- room_id = row["room_id"]
- joined_via = row["joined_via"]
+ for room_id, joined_via in rows:
room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
- rows = await self.db_pool.simple_select_list(
- "partial_state_rooms_servers",
- keyvalues=None,
- retcols=("room_id", "server_name"),
- desc="get_partial_state_rooms",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ "partial_state_rooms_servers",
+ keyvalues=None,
+ retcols=("room_id", "server_name"),
+ desc="get_partial_state_rooms",
+ ),
)
- for row in rows:
- room_id = row["room_id"]
- server_name = row["server_name"]
+ for room_id, server_name in rows:
entry = room_servers.get(room_id)
if entry is None:
# There is a foreign key constraint which enforces that every room_id in
@@ -1300,14 +1350,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
complete.
"""
- rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch(
- table="partial_state_rooms",
- column="room_id",
- iterable=room_ids,
- retcols=("room_id",),
- desc="is_partial_state_room_batched",
- )
- partial_state_rooms = {row_dict["room_id"] for row_dict in rows}
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="partial_state_rooms",
+ column="room_id",
+ iterable=room_ids,
+ retcols=("room_id",),
+ desc="is_partial_state_room_batched",
+ ),
+ )
+ partial_state_rooms = {row[0] for row in rows}
return {room_id: room_id in partial_state_rooms for room_id in room_ids}
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
@@ -1703,24 +1756,24 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
(last_room, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return True
- for row in rows:
- if not row["json"]:
+ for room_id, event_id, json in rows:
+ if not json:
retention_policy = {}
else:
- ev = db_to_json(row["json"])
+ ev = db_to_json(json)
retention_policy = ev["content"]
self.db_pool.simple_insert_txn(
txn=txn,
table="room_retention",
values={
- "room_id": row["room_id"],
- "event_id": row["event_id"],
+ "room_id": room_id,
+ "event_id": event_id,
"min_lifetime": retention_policy.get("min_lifetime"),
"max_lifetime": retention_policy.get("max_lifetime"),
},
@@ -1729,7 +1782,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Inserted %d rows into room_retention", len(rows))
self.db_pool.updates._background_update_progress_txn(
- txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+ txn, "insert_room_retention", {"room_id": rows[-1][0]}
)
if batch_size > len(rows):
@@ -2215,7 +2268,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
txn,
table="partial_state_rooms_servers",
keys=("room_id", "server_name"),
- values=((room_id, s) for s in servers),
+ values=[(room_id, s) for s in servers],
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3755773faa..1ed7f2d0ef 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -27,6 +27,7 @@ from typing import (
Set,
Tuple,
Union,
+ cast,
)
import attr
@@ -275,7 +276,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
_get_users_in_room_with_profiles,
)
- @cached(max_entries=100000)
+ @cached(max_entries=100000) # type: ignore[synapse-@cached-mutable]
async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
"""Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
@@ -481,6 +482,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
desc="get_local_users_in_room",
)
+ async def get_local_users_related_to_room(
+ self, room_id: str
+ ) -> List[Tuple[str, str]]:
+ """
+ Retrieves a list of the current room members who are local to the server and their membership status.
+ """
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="local_current_membership",
+ keyvalues={"room_id": room_id},
+ retcols=("user_id", "membership"),
+ desc="get_local_users_in_room",
+ ),
+ )
+
async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
"""
Check whether a given local user is currently joined to the given room.
@@ -683,25 +700,28 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from user_id to set of rooms that is currently in.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="current_state_events",
- column="state_key",
- iterable=user_ids,
- retcols=(
- "state_key",
- "room_id",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="current_state_events",
+ column="state_key",
+ iterable=user_ids,
+ retcols=(
+ "state_key",
+ "room_id",
+ ),
+ keyvalues={
+ "type": EventTypes.Member,
+ "membership": Membership.JOIN,
+ },
+ desc="get_rooms_for_users",
),
- keyvalues={
- "type": EventTypes.Member,
- "membership": Membership.JOIN,
- },
- desc="get_rooms_for_users",
)
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
- for row in rows:
- user_rooms[row["state_key"]].add(row["room_id"])
+ for state_key, room_id in rows:
+ user_rooms[state_key].add(room_id)
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
@@ -892,17 +912,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from event ID to `user_id`, or None if event is not a join.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=event_ids,
- retcols=("user_id", "event_id"),
- keyvalues={"membership": Membership.JOIN},
- batch_size=1000,
- desc="_get_user_ids_from_membership_event_ids",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=event_ids,
+ retcols=("event_id", "user_id"),
+ keyvalues={"membership": Membership.JOIN},
+ batch_size=1000,
+ desc="_get_user_ids_from_membership_event_ids",
+ ),
)
- return {row["event_id"]: row["user_id"] for row in rows}
+ return dict(rows)
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -933,7 +956,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
like_clause = "%:" + host
rows = await self.db_pool.execute(
- "is_host_joined", None, sql, membership, room_id, like_clause
+ "is_host_joined", sql, membership, room_id, like_clause
)
if not rows:
@@ -1063,15 +1086,19 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
for fully-joined rooms.
"""
- rows = await self.db_pool.simple_select_list(
- "current_state_events",
- keyvalues={"room_id": room_id},
- retcols=("event_id", "membership"),
- desc="has_completed_background_updates",
+ rows = cast(
+ List[Tuple[str, Optional[str]]],
+ await self.db_pool.simple_select_list(
+ "current_state_events",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "membership"),
+ desc="has_completed_background_updates",
+ ),
)
- return {row["event_id"]: row["membership"] for row in rows}
+ return dict(rows)
- @cached(max_entries=10000)
+ # TODO This returns a mutable object, which is generally confusing when using a cache.
+ @cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
return _JoinedHostsCache()
@@ -1157,7 +1184,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
AND forgotten = 0;
"""
- rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+ rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)
# `count(*)` returns always an integer
# If any rows still exist it means someone has not forgotten this room yet
@@ -1201,21 +1228,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
membership event, otherwise the value is None.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=member_event_ids,
- retcols=("user_id", "membership", "event_id"),
- keyvalues={},
- batch_size=500,
- desc="get_membership_from_event_ids",
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids,
+ retcols=("user_id", "membership", "event_id"),
+ keyvalues={},
+ batch_size=500,
+ desc="get_membership_from_event_ids",
+ ),
)
return {
- row["event_id"]: EventIdMembership(
- membership=row["membership"], user_id=row["user_id"]
- )
- for row in rows
+ event_id: EventIdMembership(membership=membership, user_id=user_id)
+ for user_id, membership, event_id in rows
}
async def is_local_host_in_room_ignoring_users(
@@ -1348,18 +1376,16 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return 0
- min_stream_id = rows[-1]["stream_ordering"]
+ min_stream_id = rows[-1][0]
to_update = []
- for row in rows:
- event_id = row["event_id"]
- room_id = row["room_id"]
+ for _, event_id, room_id, json in rows:
try:
- event_json = db_to_json(row["json"])
+ event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index d45d2ecc98..4c4112e3b2 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -26,6 +26,7 @@ from typing import (
Set,
Tuple,
Union,
+ cast,
)
import attr
@@ -105,7 +106,7 @@ class SearchWorkerStore(SQLBaseStore):
txn,
table="event_search",
keys=("event_id", "room_id", "key", "value"),
- values=(
+ values=[
(
entry.event_id,
entry.room_id,
@@ -113,7 +114,7 @@ class SearchWorkerStore(SQLBaseStore):
_clean_value_for_search(entry.value),
)
for entry in entries
- ),
+ ],
)
else:
@@ -179,22 +180,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return 0
- min_stream_id = rows[-1]["stream_ordering"]
+ min_stream_id = rows[-1][0]
event_search_rows = []
- for row in rows:
+ for (
+ stream_ordering,
+ event_id,
+ room_id,
+ etype,
+ json,
+ origin_server_ts,
+ ) in rows:
try:
- event_id = row["event_id"]
- room_id = row["room_id"]
- etype = row["type"]
- stream_ordering = row["stream_ordering"]
- origin_server_ts = row["origin_server_ts"]
try:
- event_json = db_to_json(row["json"])
+ event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue
@@ -504,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = await self.db_pool.execute(
- "search_msgs", self.db_pool.cursor_to_dict, sql, *args
+ # List of tuples of (rank, room_id, event_id).
+ results = cast(
+ List[Tuple[Union[int, float], str, str]],
+ await self.db_pool.execute("search_msgs", sql, *args),
)
- results = list(filter(lambda row: row["room_id"] in room_ids, results))
+ results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
- [r["event_id"] for r in results],
+ [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@@ -525,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = await self.db_pool.execute(
- "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+ # List of tuples of (room_id, count).
+ count_results = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
- count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+ count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
- {"event": event_map[r["event_id"]], "rank": r["rank"]}
+ {"event": event_map[r[2]], "rank": r[0]}
for r in results
- if r["event_id"] in event_map
+ if r[2] in event_map
],
"highlights": highlights,
"count": count,
@@ -602,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = search_term
sql = """
SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
- origin_server_ts, stream_ordering, room_id, event_id
+ room_id, event_id, origin_server_ts, stream_ordering
FROM event_search
WHERE vector @@ websearch_to_tsquery('english', ?) AND
"""
@@ -663,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# mypy expects to append only a `str`, not an `int`
args.append(limit)
- results = await self.db_pool.execute(
- "search_rooms", self.db_pool.cursor_to_dict, sql, *args
+ # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
+ results = cast(
+ List[Tuple[Union[int, float], str, str, int, int]],
+ await self.db_pool.execute("search_rooms", sql, *args),
)
- results = list(filter(lambda row: row["room_id"] in room_ids, results))
+ results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
- [r["event_id"] for r in results],
+ [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@@ -684,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = await self.db_pool.execute(
- "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+ # List of tuples of (room_id, count).
+ count_results = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
- count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+ count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
{
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- "pagination_token": "%s,%s"
- % (r["origin_server_ts"], r["stream_ordering"]),
+ "event": event_map[r[2]],
+ "rank": r[0],
+ "pagination_token": "%s,%s" % (r[3], r[4]),
}
for r in results
- if r["event_id"] in event_map
+ if r[2] in event_map
],
"highlights": highlights,
"count": count,
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5eaaff5b68..598025dd91 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -20,10 +20,12 @@ from typing import (
Collection,
Dict,
Iterable,
+ List,
Mapping,
Optional,
Set,
Tuple,
+ cast,
)
import attr
@@ -388,16 +390,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Raises:
RuntimeError if the state is unknown at any of the given events
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="event_to_state_groups",
- column="event_id",
- iterable=event_ids,
- keyvalues={},
- retcols=("event_id", "state_group"),
- desc="_get_state_group_for_events",
+ rows = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_to_state_groups",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=("event_id", "state_group"),
+ desc="_get_state_group_for_events",
+ ),
)
- res = {row["event_id"]: row["state_group"] for row in rows}
+ res = dict(rows)
for e in event_ids:
if e not in res:
raise RuntimeError("No state group for unknown or outlier event %s" % e)
@@ -415,16 +420,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
The subset of state groups that are referenced.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="event_to_state_groups",
- column="state_group",
- iterable=state_groups,
- keyvalues={},
- retcols=("DISTINCT state_group",),
- desc="get_referenced_state_groups",
+ rows = cast(
+ List[Tuple[int]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_to_state_groups",
+ column="state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("DISTINCT state_group",),
+ desc="get_referenced_state_groups",
+ ),
)
- return {row["state_group"] for row in rows}
+ return {row[0] for row in rows}
async def update_state_for_partial_state_event(
self,
@@ -624,16 +632,22 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
# potentially stale, since there may have been a period where the
# server didn't share a room with the remote user and therefore may
# have missed any device updates.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="current_state_events",
- column="room_id",
- iterable=to_delete,
- keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
- retcols=("state_key",),
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="current_state_events",
+ column="room_id",
+ iterable=to_delete,
+ keyvalues={
+ "type": EventTypes.Member,
+ "membership": Membership.JOIN,
+ },
+ retcols=("state_key",),
+ ),
)
- potentially_left_users = {row["state_key"] for row in rows}
+ potentially_left_users = {row[0] for row in rows}
# Now lets actually delete the rooms from the DB.
self.db_pool.simple_delete_many_txn(
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 445213e12a..3151186e0c 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -13,7 +13,9 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List, Tuple
+from typing import List, Optional, Tuple
+
+import attr
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
@@ -22,6 +24,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class StateDelta:
+ stream_id: int
+ room_id: str
+ event_type: str
+ state_key: str
+
+ event_id: Optional[str]
+ """new event_id for this state key. None if the state has been deleted."""
+
+ prev_event_id: Optional[str]
+ """previous event_id for this state key. None if it's new state."""
+
+
class StateDeltasStore(SQLBaseStore):
# This class must be mixed in with a child class which provides the following
# attribute. TODO: can we get static analysis to enforce this?
@@ -29,31 +45,21 @@ class StateDeltasStore(SQLBaseStore):
async def get_partial_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
- ) -> Tuple[int, List[Dict[str, Any]]]:
+ ) -> Tuple[int, List[StateDelta]]:
"""Fetch a list of room state changes since the given stream id
- Each entry in the result contains the following fields:
- - stream_id (int)
- - room_id (str)
- - type (str): event type
- - state_key (str):
- - event_id (str|None): new event_id for this state key. None if the
- state has been deleted.
- - prev_event_id (str|None): previous event_id for this state key. None
- if it's new state.
-
This may be the partial state if we're lazy joining the room.
Args:
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
- - ie, an upper limit to return changes from.
+ - ie, an upper limit to return changes from.
Returns:
A tuple consisting of:
- - the stream id which these results go up to
- - list of current_state_delta_stream rows. If it is empty, we are
- up to date.
+ - the stream id which these results go up to
+ - list of current_state_delta_stream rows. If it is empty, we are
+ up to date.
"""
prev_stream_id = int(prev_stream_id)
@@ -72,7 +78,7 @@ class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas_txn(
txn: LoggingTransaction,
- ) -> Tuple[int, List[Dict[str, Any]]]:
+ ) -> Tuple[int, List[StateDelta]]:
# First we calculate the max stream id that will give us less than
# N results.
# We arbitrarily limit to 100 stream_id entries to ensure we don't
@@ -112,7 +118,17 @@ class StateDeltasStore(SQLBaseStore):
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, clipped_stream_id))
- return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
+ return clipped_stream_id, [
+ StateDelta(
+ stream_id=row[0],
+ room_id=row[1],
+ event_type=row[2],
+ state_key=row[3],
+ event_id=row[4],
+ prev_event_id=row[5],
+ )
+ for row in txn.fetchall()
+ ]
return await self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9d403919e4..e96c9b0486 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -506,25 +506,28 @@ class StatsStore(StateDeltasStore):
) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="current_state_events",
- column="type",
- iterable=[
- EventTypes.Create,
- EventTypes.JoinRules,
- EventTypes.RoomHistoryVisibility,
- EventTypes.RoomEncryption,
- EventTypes.Name,
- EventTypes.Topic,
- EventTypes.RoomAvatar,
- EventTypes.CanonicalAlias,
- ],
- keyvalues={"room_id": room_id, "state_key": ""},
- retcols=["event_id"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="current_state_events",
+ column="type",
+ iterable=[
+ EventTypes.Create,
+ EventTypes.JoinRules,
+ EventTypes.RoomHistoryVisibility,
+ EventTypes.RoomEncryption,
+ EventTypes.Name,
+ EventTypes.Topic,
+ EventTypes.RoomAvatar,
+ EventTypes.CanonicalAlias,
+ ],
+ keyvalues={"room_id": room_id, "state_key": ""},
+ retcols=["event_id"],
+ ),
)
- event_ids = cast(List[str], [row["event_id"] for row in rows])
+ event_ids = [row[0] for row in rows]
txn.execute(
"""
@@ -676,7 +679,7 @@ class StatsStore(StateDeltasStore):
order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
(size and number). This will return a json list of users and the
total number of users matching the filter criteria.
@@ -689,14 +692,19 @@ class StatsStore(StateDeltasStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
search_term: a string to filter user names by
+
Returns:
- A list of user dicts and an integer representing the total number of
- users that exist given this query
+ A tuple of:
+ A list of tuples of user information (the user ID, displayname,
+ total number of media, total length of media) and
+
+ An integer representing the total number of users that exist
+ given this query
"""
def get_users_media_usage_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
filters = []
args: list = []
@@ -770,7 +778,7 @@ class StatsStore(StateDeltasStore):
args += [limit, start]
txn.execute(sql, args)
- users = self.db_pool.cursor_to_dict(txn)
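+ # Tuples of (user ID, displayname, total media count, total media length).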
+ users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())
return users, count
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 5a3611c415..2225f8272d 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -266,7 +266,7 @@ def generate_next_token(
# when we are going backwards so we subtract one from the
# stream part.
last_stream_ordering -= 1
- return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+ return RoomStreamToken(topological=last_topo_ordering, stream=last_stream_ordering)
def _make_generic_sql_bound(
@@ -558,7 +558,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if p > min_pos
}
- return RoomStreamToken(None, min_pos, immutabledict(positions))
+ return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
async def get_room_events_stream_for_rooms(
self,
@@ -708,7 +708,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
- key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
+ key = RoomStreamToken(stream=min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -969,7 +969,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
topo = await self.db_pool.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
- return RoomStreamToken(topo, stream_ordering)
+ return RoomStreamToken(topological=topo, stream=stream_ordering)
@overload
def get_stream_id_for_event_txn(
@@ -1033,7 +1033,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
)
- return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
+ return RoomStreamToken(
+ topological=row["topological_ordering"], stream=row["stream_ordering"]
+ )
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream
@@ -1076,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
row = await self.db_pool.execute(
- "get_current_topological_token", None, sql, room_id, room_id, stream_key
+ "get_current_topological_token", sql, room_id, room_id, stream_key
)
return row[0][0] if row else 0
@@ -1114,8 +1116,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
else:
topo = None
internal = event.internal_metadata
- internal.before = RoomStreamToken(topo, stream - 1)
- internal.after = RoomStreamToken(topo, stream)
+ internal.before = RoomStreamToken(topological=topo, stream=stream - 1)
+ internal.after = RoomStreamToken(topological=topo, stream=stream)
internal.order = (int(topo) if topo else 0, int(stream))
async def get_events_around(
@@ -1191,11 +1193,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
- results["topological_ordering"] - 1, results["stream_ordering"]
+ topological=results["topological_ordering"] - 1,
+ stream=results["stream_ordering"],
)
after_token = RoomStreamToken(
- results["topological_ordering"], results["stream_ordering"]
+ topological=results["topological_ordering"],
+ stream=results["stream_ordering"],
)
rows, start_token = self._paginate_room_events_txn(
@@ -1612,3 +1616,49 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcol="instance_name",
desc="get_name_from_instance_id",
)
+
+ async def get_timeline_gaps(
+ self,
+ room_id: str,
+ from_token: Optional[RoomStreamToken],
+ to_token: RoomStreamToken,
+ ) -> Optional[RoomStreamToken]:
+ """Check if there is a gap, and return a token that marks the position
+ of the gap in the stream.
+ """
+
+ sql = """
+ SELECT instance_name, stream_ordering
+ FROM timeline_gaps
+ WHERE room_id = ? AND ? < stream_ordering AND stream_ordering <= ?
+ ORDER BY stream_ordering
+ """
+
+ rows = await self.db_pool.execute(
+ "get_timeline_gaps",
+ sql,
+ room_id,
+ from_token.stream if from_token else 0,
+ to_token.get_max_stream_pos(),
+ )
+
+ if not rows:
+ return None
+
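+ # The SQL bounds above only use the min/max stream positions of the tokens,
+ # so re-check each returned position against the tokens precisely here.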
+ positions = [
+ PersistedEventPosition(instance_name, stream_ordering)
+ for instance_name, stream_ordering in rows
+ ]
+ if from_token:
+ positions = [p for p in positions if p.persisted_after(from_token)]
+
+ positions = [p for p in positions if not p.persisted_after(to_token)]
+
+ if positions:
+ # We return a stream token that ensures the event *at* the position
+ # of the gap is included (as the gap is *before* the persisted
+ # event).
+ last_position = positions[-1]
+ return RoomStreamToken(stream=last_position.stream - 1)
+
+ return None
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 61403a98cf..7deda7790e 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag content.
"""
- rows = await self.db_pool.simple_select_list(
- "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_list(
+ "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
+ ),
)
tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
- for row in rows:
- room_tags = tags_by_room.setdefault(row["room_id"], {})
- room_tags[row["tag"]] = db_to_json(row["content"])
+ for room_id, tag, content in rows:
+ room_tags = tags_by_room.setdefault(room_id, {})
+ room_tags[tag] = db_to_json(content)
return tags_by_room
async def get_all_updated_tags(
@@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns:
A mapping of tags to tag content.
"""
- rows = await self.db_pool.simple_select_list(
- table="room_tags",
- keyvalues={"user_id": user_id, "room_id": room_id},
- retcols=("tag", "content"),
- desc="get_tags_for_room",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="room_tags",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ retcols=("tag", "content"),
+ desc="get_tags_for_room",
+ ),
)
- return {row["tag"]: db_to_json(row["content"]) for row in rows}
+ return {tag: db_to_json(content) for tag, content in rows}
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
index 5c5372a825..5555b53575 100644
--- a/synapse/storage/databases/main/task_scheduler.py
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -27,6 +27,8 @@ from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
+ScheduledTaskRow = Tuple[str, str, str, int, str, str, str, str]
+
class TaskSchedulerWorkerStore(SQLBaseStore):
def __init__(
@@ -38,13 +40,18 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
@staticmethod
- def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
- row["status"] = TaskStatus(row["status"])
- if row["params"] is not None:
- row["params"] = db_to_json(row["params"])
- if row["result"] is not None:
- row["result"] = db_to_json(row["result"])
- return ScheduledTask(**row)
+ def _convert_row_to_task(row: ScheduledTaskRow) -> ScheduledTask:
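+ # Unpack in the column order defined by ScheduledTaskRow.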
+ task_id, action, status, timestamp, resource_id, params, result, error = row
+ return ScheduledTask(
+ id=task_id,
+ action=action,
+ status=TaskStatus(status),
+ timestamp=timestamp,
+ resource_id=resource_id,
+ params=db_to_json(params) if params is not None else None,
+ result=db_to_json(result) if result is not None else None,
+ error=error,
+ )
async def get_scheduled_tasks(
self,
@@ -68,7 +75,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: a list of `ScheduledTask`, ordered by increasing timestamps
"""
- def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[ScheduledTaskRow]:
clauses: List[str] = []
args: List[Any] = []
if resource_id:
@@ -101,7 +108,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[ScheduledTaskRow], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_scheduled_tasks", get_scheduled_tasks_txn
@@ -193,7 +200,22 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
desc="get_scheduled_task",
)
- return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
+ return (
+ TaskSchedulerWorkerStore._convert_row_to_task(
+ (
+ row["id"],
+ row["action"],
+ row["status"],
+ row["timestamp"],
+ row["resource_id"],
+ row["params"],
+ row["result"],
+ row["error"],
+ )
+ )
+ if row
+ else None
+ )
async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id.
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 8f70eff809..fecddb4144 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -211,18 +211,28 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
async def get_destination_retry_timings_batch(
self, destinations: StrCollection
) -> Mapping[str, Optional[DestinationRetryTimings]]:
- rows = await self.db_pool.simple_select_many_batch(
- table="destinations",
- iterable=destinations,
- column="destination",
- retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
- desc="get_destination_retry_timings_batch",
+ rows = cast(
+ List[Tuple[str, Optional[int], Optional[int], Optional[int]]],
+ await self.db_pool.simple_select_many_batch(
+ table="destinations",
+ iterable=destinations,
+ column="destination",
+ retcols=(
+ "destination",
+ "failure_ts",
+ "retry_last_ts",
+ "retry_interval",
+ ),
+ desc="get_destination_retry_timings_batch",
+ ),
)
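+ # Rows are (destination, failure_ts, retry_last_ts, retry_interval); skip
+ # destinations without complete retry timing info recorded.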
return {
- row.pop("destination"): DestinationRetryTimings(**row)
- for row in rows
- if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"]
+ destination: DestinationRetryTimings(
+ failure_ts, retry_last_ts, retry_interval
+ )
+ for destination, failure_ts, retry_last_ts, retry_interval in rows
+ if retry_last_ts and failure_ts and retry_interval
}
async def set_destination_retry_timings(
@@ -468,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[
+ List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]],
+ int,
+ ]:
"""Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the
total number of destinations matching the filter criteria.
@@ -480,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
Returns:
- A tuple of a list of mappings from destination to information
+ A tuple of a list of tuples of destination information:
+ * destination
+ * retry_last_ts
+ * retry_interval
+ * failure_ts
+ * last_successful_stream_ordering
and a count of total destinations.
"""
def get_destinations_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[
+ List[
+ Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]
+ ],
+ int,
+ ]:
order_by_column = DestinationSortOrder(order_by).value
if direction == Direction.BACKWARDS:
@@ -513,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
LIMIT ? OFFSET ?
"""
txn.execute(sql, args + [limit, start])
- destinations = self.db_pool.cursor_to_dict(txn)
+ destinations = cast(
+ List[
+ Tuple[
+ str, Optional[int], Optional[int], Optional[int], Optional[int]
+ ]
+ ],
+ txn.fetchall(),
+ )
return destinations, count
return await self.db_pool.runInteraction(
@@ -526,7 +556,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
start: int,
limit: int,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, int]], int]:
"""Function to retrieve a paginated list of destination's rooms.
This will return a json list of rooms and the
total number of rooms.
@@ -537,12 +567,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
limit: number of rows to retrieve
direction: sort ascending or descending by room_id
Returns:
- A tuple of a dict of rooms and a count of total rooms.
+ A tuple of a list of room tuples and a count of total rooms.
+
+ Each room tuple is room_id, stream_ordering.
"""
def get_destination_rooms_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, int]], int]:
if direction == Direction.BACKWARDS:
order = "DESC"
else:
@@ -556,14 +588,17 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, [destination])
count = cast(Tuple[int], txn.fetchone())[0]
- rooms = self.db_pool.simple_select_list_paginate_txn(
- txn=txn,
- table="destination_rooms",
- orderby="room_id",
- start=start,
- limit=limit,
- retcols=("room_id", "stream_ordering"),
- order_direction=order,
+ rooms = cast(
+ List[Tuple[str, int]],
+ self.db_pool.simple_select_list_paginate_txn(
+ txn=txn,
+ table="destination_rooms",
+ orderby="room_id",
+ start=start,
+ limit=limit,
+ retcols=("room_id", "stream_ordering"),
+ order_direction=order,
+ ),
)
return rooms, count
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index f38bedbbcd..8ab7c42c4a 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore):
that auth-type.
"""
results = {}
- for row in await self.db_pool.simple_select_list(
- table="ui_auth_sessions_credentials",
- keyvalues={"session_id": session_id},
- retcols=("stage_type", "result"),
- desc="get_completed_ui_auth_stages",
- ):
- results[row["stage_type"]] = db_to_json(row["result"])
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id},
+ retcols=("stage_type", "result"),
+ desc="get_completed_ui_auth_stages",
+ ),
+ )
+ for stage_type, result in rows:
+ results[stage_type] = db_to_json(result)
return results
@@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore):
Returns:
List of user_agent/ip pairs
"""
- rows = await self.db_pool.simple_select_list(
- table="ui_auth_sessions_ips",
- keyvalues={"session_id": session_id},
- retcols=("user_agent", "ip"),
- desc="get_user_agents_ips_to_ui_auth_session",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ ),
)
- return [(row["user_agent"], row["ip"]) for row in rows]
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
@@ -337,13 +343,16 @@ class UIAuthWorkerStore(SQLBaseStore):
# If a registration token was used, decrement the pending counter
# before deleting the session.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="ui_auth_sessions_credentials",
- column="session_id",
- iterable=session_ids,
- keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
- retcols=["result"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+ retcols=["result"],
+ ),
)
# Get the tokens used and how much pending needs to be decremented by.
@@ -353,23 +362,25 @@ class UIAuthWorkerStore(SQLBaseStore):
# registration token stage for that session will be True.
# If a token was used to authenticate, but registration was
# never completed, the result will be the token used.
- token = db_to_json(r["result"])
+ token = db_to_json(r[0])
if isinstance(token, str):
token_counts[token] = token_counts.get(token, 0) + 1
# Update the `pending` counters.
if len(token_counts) > 0:
- token_rows = self.db_pool.simple_select_many_txn(
- txn,
- table="registration_tokens",
- column="token",
- iterable=list(token_counts.keys()),
- keyvalues={},
- retcols=["token", "pending"],
+ token_rows = cast(
+ List[Tuple[str, int]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="registration_tokens",
+ column="token",
+ iterable=list(token_counts.keys()),
+ keyvalues={},
+ retcols=["token", "pending"],
+ ),
)
- for token_row in token_rows:
- token = token_row["token"]
- new_pending = token_row["pending"] - token_counts[token]
+ for token, pending in token_rows:
+ new_pending = pending - token_counts[token]
self.db_pool.simple_update_one_txn(
txn,
table="registration_tokens",
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index ed41e52201..d4b86ed7a6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -415,25 +415,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
# Next fetch their profiles. Note that not all users have profiles.
- profile_rows = self.db_pool.simple_select_many_txn(
- txn,
- table="profiles",
- column="full_user_id",
- iterable=list(users_to_insert),
- retcols=(
- "full_user_id",
- "displayname",
- "avatar_url",
+ profile_rows = cast(
+ List[Tuple[str, Optional[str], Optional[str]]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="profiles",
+ column="full_user_id",
+ iterable=list(users_to_insert),
+ retcols=(
+ "full_user_id",
+ "displayname",
+ "avatar_url",
+ ),
+ keyvalues={},
),
- keyvalues={},
)
profiles = {
- row["full_user_id"]: _UserDirProfile(
- row["full_user_id"],
- row["displayname"],
- row["avatar_url"],
- )
- for row in profile_rows
+ full_user_id: _UserDirProfile(full_user_id, displayname, avatar_url)
+ for full_user_id, displayname, avatar_url in profile_rows
}
profiles_to_insert = [
@@ -522,18 +521,21 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined]
]
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="users",
- column="name",
- iterable=users,
- keyvalues={
- "deactivated": 0,
- },
- retcols=("name", "user_type"),
+ rows = cast(
+ List[Tuple[str, Optional[str]]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="users",
+ column="name",
+ iterable=users,
+ keyvalues={
+ "deactivated": 0,
+ },
+ retcols=("name", "user_type"),
+ ),
)
- return [row["name"] for row in rows if row["user_type"] != UserTypes.SUPPORT]
+ return [name for name, user_type in rows if user_type != UserTypes.SUPPORT]
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable"""
@@ -1178,15 +1180,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
raise Exception("Unrecognized database engine")
results = cast(
- List[UserProfile],
- await self.db_pool.execute(
- "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
- ),
+ List[Tuple[str, Optional[str], Optional[str]]],
+ await self.db_pool.execute("search_user_dir", sql, *args),
)
limited = len(results) > limit
- return {"limited": limited, "results": results[0:limit]}
+ return {
+ "limited": limited,
+ "results": [
+ {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
+ for r in results[0:limit]
+ ],
+ }
def _filter_text_for_index(text: str) -> str:
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 06fcbe5e54..8bd58c6e3d 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, Mapping
+from typing import Iterable, List, Mapping, Tuple, cast
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -50,14 +50,17 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
Returns:
for each user, whether the user has requested erasure.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="erased_users",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id",),
- desc="are_users_erased",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="erased_users",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="are_users_erased",
+ ),
)
- erased_users = {row["user_id"] for row in rows}
+ erased_users = {row[0] for row in rows}
return {u: u in erased_users for u in user_ids}
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index d2e942cbd3..2c3151526d 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
if max_group is None:
rows = await self.db_pool.execute(
"_background_deduplicate_state",
- None,
"SELECT coalesce(max(id), 0) FROM state_groups",
)
max_group = rows[0][0]
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 6984d11352..182e429174 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,7 +13,17 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
import attr
@@ -144,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not prev_group:
return _GetStateGroupDelta(None, None)
- delta_ids = self.db_pool.simple_select_list_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- retcols=("type", "state_key", "event_id"),
+ delta_ids = cast(
+ List[Tuple[str, str, str]],
+ self.db_pool.simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ retcols=("type", "state_key", "event_id"),
+ ),
)
return _GetStateGroupDelta(
prev_group,
- {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+ {
+ (event_type, state_key): event_id
+ for event_type, state_key, event_id in delta_ids
+ },
)
return await self.db_pool.runInteraction(
@@ -730,19 +746,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="state_group_edges",
- column="prev_state_group",
- iterable=state_groups_to_delete,
- keyvalues={},
- retcols=("state_group",),
+ rows = cast(
+ List[Tuple[int]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ retcols=("state_group",),
+ ),
)
remaining_state_groups = {
- row["state_group"]
- for row in rows
- if row["state_group"] not in state_groups_to_delete
+ state_group
+ for state_group, in rows
+ if state_group not in state_groups_to_delete
}
logger.info(
@@ -799,16 +818,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
A mapping from state group to previous state group.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="state_group_edges",
- column="prev_state_group",
- iterable=state_groups,
- keyvalues={},
- retcols=("prev_state_group", "state_group"),
- desc="get_previous_state_groups",
+ rows = cast(
+ List[Tuple[int, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("state_group", "prev_state_group"),
+ desc="get_previous_state_groups",
+ ),
)
- return {row["state_group"]: row["prev_state_group"] for row in rows}
+ return dict(rows)
async def purge_room_state(
self, room_id: str, state_groups_to_delete: Collection[int]
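Because `retcols` is now ordered `("state_group", "prev_state_group")`, each row returned by `get_previous_state_groups` is already a (key, value) pair and `dict(rows)` builds the mapping directly. A toy sketch with invented state group IDs:

    # Rows come back in retcols order: (state_group, prev_state_group).
    rows = [(105, 101), (110, 105), (120, 110)]

    # dict() treats each two-tuple as a (key, value) pair, so the explicit
    # comprehension over dict-style rows collapses to a single call.
    previous_state_groups = dict(rows)
    assert previous_state_groups[110] == 105

The column order is what makes this safe: with the previous `("prev_state_group", "state_group")` order, `dict(rows)` would have produced the inverse mapping.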
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 2500381b7b..cbfb32014c 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -45,6 +45,7 @@ class ProfileInfo:
display_name: Optional[str]
+# TODO This is used as a cached value and is mutable.
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
class MemberSummary:
# A truncated list of (user_id, event_id) tuples for users of a given
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 5b50bd66bc..158b528dce 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-SCHEMA_VERSION = 82 # remember to update the list below when updating
+SCHEMA_VERSION = 83 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -121,6 +121,9 @@ Changes in SCHEMA_VERSION = 81
Changes in SCHEMA_VERSION = 82
- The insertion_events, insertion_event_extremities, insertion_event_edges, and
batch_events tables are no longer purged in preparation for their removal.
+
+Changes in SCHEMA_VERSION = 83
+ - The event_txn_id table is no longer used.
"""
diff --git a/synapse/storage/schema/main/delta/82/04_add_indices_for_purging_rooms.sql b/synapse/storage/schema/main/delta/82/04_add_indices_for_purging_rooms.sql
new file mode 100644
index 0000000000..fc948166e6
--- /dev/null
+++ b/synapse/storage/schema/main/delta/82/04_add_indices_for_purging_rooms.sql
@@ -0,0 +1,20 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (8204, 'e2e_room_keys_index_room_id', '{}');
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (8204, 'room_account_data_index_room_id', '{}');
diff --git a/synapse/storage/schema/main/delta/82/05gaps.sql b/synapse/storage/schema/main/delta/82/05gaps.sql
new file mode 100644
index 0000000000..6813b488ca
--- /dev/null
+++ b/synapse/storage/schema/main/delta/82/05gaps.sql
@@ -0,0 +1,25 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Records when we see a "gap in the timeline", due to missing events over
+-- federation. We record this so that we can tell clients there is a gap (by
+-- marking the timeline section of a sync request as limited).
+CREATE TABLE IF NOT EXISTS timeline_gaps (
+ room_id TEXT NOT NULL,
+ instance_name TEXT NOT NULL,
+ stream_ordering BIGINT NOT NULL
+);
+
+CREATE INDEX timeline_gaps_room_id ON timeline_gaps(room_id, stream_ordering);
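The delta only creates the table and index; as a rough illustration of how it might be exercised, the sketch below runs the same DDL against an in-memory SQLite database. The insert and the sync-style lookup are assumptions about usage, not queries taken from this change:

    import sqlite3

    # Illustrative only: the DDL mirrors the delta above; the insert and the
    # lookup are assumptions about how a recorded gap might be used.
    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS timeline_gaps (
            room_id TEXT NOT NULL,
            instance_name TEXT NOT NULL,
            stream_ordering BIGINT NOT NULL
        )
        """
    )
    conn.execute(
        "CREATE INDEX timeline_gaps_room_id ON timeline_gaps(room_id, stream_ordering)"
    )

    # Record a gap seen in a (made-up) room at a given stream position.
    conn.execute(
        "INSERT INTO timeline_gaps (room_id, instance_name, stream_ordering)"
        " VALUES (?, ?, ?)",
        ("!room:example.com", "master", 1234),
    )

    # Sync-style check: is there a recorded gap after the client's last-seen
    # position? If so, the timeline section would be marked as limited.
    row = conn.execute(
        "SELECT 1 FROM timeline_gaps WHERE room_id = ? AND stream_ordering > ? LIMIT 1",
        ("!room:example.com", 1000),
    ).fetchone()
    print("limited" if row else "not limited")  # -> limited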
diff --git a/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite
new file mode 100644
index 0000000000..6c7ad0fd37
--- /dev/null
+++ b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite
@@ -0,0 +1,17 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This already exists on Postgres.
+ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index aa4fa40c9c..52d708ad17 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -134,6 +134,15 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ def get_minimal_local_current_token(self) -> int:
+ """Tries to return a minimal current token for the local instance,
+ i.e. for writers this would be the last successful write.
+
+ If the local instance is not a writer (or has not written yet) then it falls back
+ to returning the normal "current token".
+ """
+
+ @abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
@@ -312,6 +321,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def get_current_token_for_writer(self, instance_name: str) -> int:
return self.get_current_token()
+ def get_minimal_local_current_token(self) -> int:
+ return self.get_current_token()
+
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with multiple writers.
@@ -408,6 +420,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# The maximum stream ID that we have seen been allocated across any writer.
self._max_seen_allocated_stream_id = 1
+ # The maximum position of the local instance. This can be higher than
+ # the corresponding entry in `current_positions` when there are no
+ # active writes in progress.
+ self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
@@ -427,6 +444,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._current_positions.values(), default=1
)
+ # For the case where `stream_positions` is not up to date,
+ # `_persisted_upto_position` may be higher.
+ self._max_seen_allocated_stream_id = max(
+ self._max_seen_allocated_stream_id, self._persisted_upto_position
+ )
+
+ # Bump our local maximum position now that we've loaded things from the
+ # DB.
+ self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
if not writers:
# If there have been no explicit writers given then any instance can
# write to the stream. In which case, let's pre-seed our own
@@ -545,6 +572,14 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if instance == self._instance_name:
self._current_positions[instance] = stream_id
+ if self._writers:
+ # If we have explicit writers then make sure that each instance has
+ # a position.
+ for writer in self._writers:
+ self._current_positions.setdefault(
+ writer, self._persisted_upto_position
+ )
+
cur.close()
def _load_next_id_txn(self, txn: Cursor) -> int:
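The seeding loop above can be shown in isolation; the writer names and values below are invented. Writers listed in the configuration but missing from the loaded positions are given the persisted-up-to position, while writers with a known position keep it:

    # Invented configuration: three explicit writers, one of which already has
    # a position loaded from the stream_positions table.
    writers = ["worker1", "worker2", "worker3"]
    persisted_upto_position = 7
    current_positions = {"worker1": 9}

    # Writers without a stored position are seeded at the persisted-up-to
    # position so that every configured writer has an entry.
    for writer in writers:
        current_positions.setdefault(writer, persisted_upto_position)

    assert current_positions == {"worker1": 9, "worker2": 7, "worker3": 7}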
@@ -688,6 +723,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, new_cur)
+ self._max_position_of_local_instance = max(
+ curr, new_cur, self._max_position_of_local_instance
+ )
self._add_persisted_position(next_id)
@@ -702,10 +740,26 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# persisted up to position. This stops Synapse from doing a full table
# scan when a new writer announces itself over replication.
with self._lock:
- return self._return_factor * self._current_positions.get(
+ if self._instance_name == instance_name:
+ return self._return_factor * self._max_position_of_local_instance
+
+ pos = self._current_positions.get(
instance_name, self._persisted_upto_position
)
+ # We want to return the maximum "current token" that we can for a
+ # writer, as this helps ensure that streams progress as fast as
+ # possible.
+ pos = max(pos, self._persisted_upto_position)
+
+ return self._return_factor * pos
+
+ def get_minimal_local_current_token(self) -> int:
+ with self._lock:
+ return self._return_factor * self._current_positions.get(
+ self._instance_name, self._persisted_upto_position
+ )
+
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
@@ -774,6 +828,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+ # Advance our local max position.
+ self._max_position_of_local_instance = max(
+ self._max_position_of_local_instance, self._persisted_upto_position
+ )
+
+ if not self._unfinished_ids and not self._in_flight_fetches:
+ # If we don't have anything in flight, it's safe to advance to the
+ # max seen stream ID.
+ self._max_position_of_local_instance = max(
+ self._max_seen_allocated_stream_id, self._max_position_of_local_instance
+ )
+
# We now iterate through the seen positions, discarding those that are
# less than the current min positions, and incrementing the min position
# if it's exactly one greater.
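The catch-up rule added to `_add_persisted_position` can be illustrated on its own (all values invented): once nothing is in flight locally, the local maximum may safely jump to the highest stream ID seen from any writer:

    # Invented bookkeeping state for the local writer.
    max_seen_allocated_stream_id = 15    # highest ID seen allocated by any writer
    max_position_of_local_instance = 12
    unfinished_ids: set = set()          # IDs handed out locally, not yet persisted
    in_flight_fetches: set = set()       # IDs currently being reserved

    if not unfinished_ids and not in_flight_fetches:
        # With nothing in flight locally there is no risk of advertising a
        # position ahead of an unpersisted local write, so catch up to the
        # maximum stream ID seen so far.
        max_position_of_local_instance = max(
            max_seen_allocated_stream_id, max_position_of_local_instance
        )

    assert max_position_of_local_instance == 15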