summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py13
-rw-r--r--synapse/storage/database.py137
-rw-r--r--synapse/storage/databases/main/__init__.py17
-rw-r--r--synapse/storage/databases/main/account_data.py93
-rw-r--r--synapse/storage/databases/main/appservice.py10
-rw-r--r--synapse/storage/databases/main/cache.py9
-rw-r--r--synapse/storage/databases/main/censor_events.py13
-rw-r--r--synapse/storage/databases/main/client_ips.py22
-rw-r--r--synapse/storage/databases/main/deviceinbox.py11
-rw-r--r--synapse/storage/databases/main/devices.py56
-rw-r--r--synapse/storage/databases/main/directory.py10
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py237
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py211
-rw-r--r--synapse/storage/databases/main/event_federation.py32
-rw-r--r--synapse/storage/databases/main/event_push_actions.py273
-rw-r--r--synapse/storage/databases/main/events.py148
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py77
-rw-r--r--synapse/storage/databases/main/events_worker.py4
-rw-r--r--synapse/storage/databases/main/filtering.py4
-rw-r--r--synapse/storage/databases/main/group_server.py10
-rw-r--r--synapse/storage/databases/main/lock.py14
-rw-r--r--synapse/storage/databases/main/media_repository.py3
-rw-r--r--synapse/storage/databases/main/metrics.py27
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py27
-rw-r--r--synapse/storage/databases/main/presence.py11
-rw-r--r--synapse/storage/databases/main/push_rule.py9
-rw-r--r--synapse/storage/databases/main/pusher.py29
-rw-r--r--synapse/storage/databases/main/receipts.py112
-rw-r--r--synapse/storage/databases/main/registration.py24
-rw-r--r--synapse/storage/databases/main/relations.py42
-rw-r--r--synapse/storage/databases/main/room.py226
-rw-r--r--synapse/storage/databases/main/roommember.py23
-rw-r--r--synapse/storage/databases/main/search.py36
-rw-r--r--synapse/storage/databases/main/state.py47
-rw-r--r--synapse/storage/databases/main/state_deltas.py4
-rw-r--r--synapse/storage/databases/main/stats.py103
-rw-r--r--synapse/storage/databases/main/stream.py54
-rw-r--r--synapse/storage/databases/main/tags.py22
-rw-r--r--synapse/storage/databases/main/transactions.py62
-rw-r--r--synapse/storage/databases/main/ui_auth.py15
-rw-r--r--synapse/storage/databases/main/user_directory.py11
-rw-r--r--synapse/storage/schema/__init__.py5
-rw-r--r--synapse/storage/util/id_generators.py6
43 files changed, 1521 insertions, 778 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py

index 3056e64ff5..7967011afd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
@@ -17,10 +17,8 @@ import logging from abc import ABCMeta from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union -from synapse.storage.database import LoggingTransaction # noqa: F401 -from synapse.storage.database import make_in_list_sql_clause # noqa: F401 -from synapse.storage.database import DatabasePool -from synapse.storage.types import Connection +from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401 +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.types import get_domain_from_id from synapse.util import json_decoder @@ -38,7 +36,12 @@ class SQLBaseStore(metaclass=ABCMeta): per data store (and not one per physical database). """ - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0693d39006..2cacc7dd6c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -13,8 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging import time +import types from collections import defaultdict from sys import intern from time import monotonic as monotonic_time @@ -53,6 +55,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -175,7 +178,7 @@ class LoggingDatabaseConnection: def rollback(self) -> None: self.conn.rollback() - def __enter__(self) -> "Connection": + def __enter__(self) -> "LoggingDatabaseConnection": self.conn.__enter__() return self @@ -526,6 +529,12 @@ class DatabasePool: the function will correctly handle being aborted and retried half way through its execution. + Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators, + since they could be evaluated multiple times (which would produce an empty + result on the second or subsequent evaluation). Likewise, the closure of `func` + must not reference any generators. This method attempts to detect such usage + and will log an error. + Args: conn desc @@ -536,6 +545,39 @@ class DatabasePool: **kwargs """ + # Robustness check: ensure that none of the arguments are generators, since that + # will fail if we have to repeat the transaction. + # For now, we just log an error, and hope that it works on the first attempt. + # TODO: raise an exception. + for i, arg in enumerate(args): + if inspect.isgenerator(arg): + logger.error( + "Programming error: generator passed to new_transaction as " + "argument %i to function %s", + i, + func, + ) + for name, val in kwargs.items(): + if inspect.isgenerator(val): + logger.error( + "Programming error: generator passed to new_transaction as " + "argument %s to function %s", + name, + func, + ) + # also check variables referenced in func's closure + if inspect.isfunction(func): + f = cast(types.FunctionType, func) + if f.__closure__: + for i, cell in enumerate(f.__closure__): + if inspect.isgenerator(cell.cell_contents): + logger.error( + "Programming error: function %s references generator %s " + "via its closure", + f, + f.__code__.co_freevars[i], + ) + start = monotonic_time() txn_id = self._TXN_ID @@ -896,6 +938,9 @@ class DatabasePool: ) -> None: """Executes an INSERT query on the named table. + The input is given as a list of dicts, with one dict per row. + Generally simple_insert_many_values should be preferred for new code. + Args: table: string giving the table name values: dict of new column names and values for them @@ -909,6 +954,9 @@ class DatabasePool: ) -> None: """Executes an INSERT query on the named table. + The input is given as a list of dicts, with one dict per row. + Generally simple_insert_many_values_txn should be preferred for new code. + Args: txn: The transaction to use. table: string giving the table name @@ -933,23 +981,66 @@ class DatabasePool: if k != keys[0]: raise RuntimeError("All items must have the same keys") + return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals) + + async def simple_insert_many_values( + self, + table: str, + keys: Collection[str], + values: Collection[Collection[Any]], + desc: str, + ) -> None: + """Executes an INSERT query on the named table. + + The input is given as a list of rows, where each row is a list of values. + (Actually any iterable is fine.) + + Args: + table: string giving the table name + keys: list of column names + values: for each row, a list of values in the same order as `keys` + desc: description of the transaction, for logging and metrics + """ + await self.runInteraction( + desc, self.simple_insert_many_values_txn, table, keys, values + ) + + @staticmethod + def simple_insert_many_values_txn( + txn: LoggingTransaction, + table: str, + keys: Collection[str], + values: Iterable[Iterable[Any]], + ) -> None: + """Executes an INSERT query on the named table. + + The input is given as a list of rows, where each row is a list of values. + (Actually any iterable is fine.) + + Args: + txn: The transaction to use. + table: string giving the table name + keys: list of column names + values: for each row, a list of values in the same order as `keys` + """ + if isinstance(txn.database_engine, PostgresEngine): # We use `execute_values` as it can be a lot faster than `execute_batch`, # but it's only available on postgres. sql = "INSERT INTO %s (%s) VALUES ?" % ( table, - ", ".join(k for k in keys[0]), + ", ".join(k for k in keys), ) - txn.execute_values(sql, vals, fetch=False) + txn.execute_values(sql, values, fetch=False) else: sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, - ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]), + ", ".join(k for k in keys), + ", ".join("?" for _ in keys), ) - txn.execute_batch(sql, vals) + txn.execute_batch(sql, values) async def simple_upsert( self, @@ -1177,9 +1268,9 @@ class DatabasePool: self, table: str, key_names: Collection[str], - key_values: Collection[Iterable[Any]], + key_values: Collection[Collection[Any]], value_names: Collection[str], - value_values: Iterable[Iterable[Any]], + value_values: Collection[Collection[Any]], desc: str, ) -> None: """ @@ -1337,7 +1428,7 @@ class DatabasePool: self, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: Literal[False] = False, desc: str = "simple_select_one", ) -> Dict[str, Any]: @@ -1348,7 +1439,7 @@ class DatabasePool: self, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: Literal[True] = True, desc: str = "simple_select_one", ) -> Optional[Dict[str, Any]]: @@ -1358,7 +1449,7 @@ class DatabasePool: self, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: bool = False, desc: str = "simple_select_one", ) -> Optional[Dict[str, Any]]: @@ -1528,7 +1619,7 @@ class DatabasePool: self, table: str, keyvalues: Optional[Dict[str, Any]], - retcols: Iterable[str], + retcols: Collection[str], desc: str = "simple_select_list", ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or @@ -1591,7 +1682,7 @@ class DatabasePool: table: str, column: str, iterable: Iterable[Any], - retcols: Iterable[str], + retcols: Collection[str], keyvalues: Optional[Dict[str, Any]] = None, desc: str = "simple_select_many_batch", batch_size: int = 100, @@ -1614,16 +1705,7 @@ class DatabasePool: results: List[Dict[str, Any]] = [] - if not iterable: - return results - - # iterables can not be sliced, so convert it to a list first - it_list = list(iterable) - - chunks = [ - it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) - ] - for chunk in chunks: + for chunk in batch_iter(iterable, batch_size): rows = await self.runInteraction( desc, self.simple_select_many_txn, @@ -1763,7 +1845,7 @@ class DatabasePool: txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: bool = False, ) -> Optional[Dict[str, Any]]: select_sql = "SELECT %s FROM %s WHERE %s" % ( @@ -1871,7 +1953,7 @@ class DatabasePool: self, table: str, column: str, - iterable: Iterable[Any], + iterable: Collection[Any], keyvalues: Dict[str, Any], desc: str, ) -> int: @@ -1882,7 +1964,8 @@ class DatabasePool: Args: table: string giving the table name column: column name to test for inclusion against `iterable` - iterable: list + iterable: list of values to match against `column`. NB cannot be a generator + as it may be evaluated multiple times. keyvalues: dict of column names and values to select the rows with desc: description of the transaction, for logging and metrics @@ -2055,7 +2138,7 @@ class DatabasePool: table: str, term: Optional[str], col: str, - retcols: Iterable[str], + retcols: Collection[str], desc="simple_search_list", ) -> Optional[List[Dict[str, Any]]]: """Executes a SELECT query on the named table, which may return zero or diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9ff2d8d8c3..f024761ba7 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.config.homeserver import HomeServerConfig -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( @@ -68,7 +68,7 @@ from .session import SessionStore from .signatures import SignatureStore from .state import StateStore from .stats import StatsStore -from .stream import StreamStore +from .stream import StreamWorkerStore from .tags import TagsStore from .transactions import TransactionWorkerStore from .ui_auth import UIAuthStore @@ -87,7 +87,7 @@ class DataStore( RoomStore, RoomBatchStore, RegistrationStore, - StreamStore, + StreamWorkerStore, ProfileStore, PresenceStore, TransactionWorkerStore, @@ -129,7 +129,12 @@ class DataStore( LockStore, SessionStore, ): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine @@ -143,11 +148,7 @@ class DataStore( ("device_lists_outbound_pokes", "stream_id"), ], ) - self._cross_signing_id_gen = StreamIdGenerator( - db_conn, "e2e_cross_signing_keys", "stream_id" - ) - self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._group_updates_id_gen = StreamIdGenerator( diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index f8bec266ac..32a553fdd7 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py
@@ -14,15 +14,25 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage._base import db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + AbstractStreamIdTracker, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -34,13 +44,19 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class AccountDataWorkerStore(SQLBaseStore): - """This is an abstract base class where subclasses must implement - `get_max_account_data_stream_id` which can be called in the initializer. - """ +class AccountDataWorkerStore(CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): - self._instance_name = hs.get_instance_name() + # `_can_write_to_account_data` indicates whether the current worker is allowed + # to write account data. A value of `True` implies that `_account_data_id_gen` + # is an `AbstractStreamIdGenerator` and not just a tracker. + self._account_data_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): self._can_write_to_account_data = ( @@ -61,8 +77,6 @@ class AccountDataWorkerStore(SQLBaseStore): writers=hs.config.worker.writers.account_data, ) else: - self._can_write_to_account_data = True - # We shouldn't be running in worker mode with SQLite, but its useful # to support it for unit tests. # @@ -70,7 +84,8 @@ class AccountDataWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.account_data: + if self._instance_name in hs.config.worker.writers.account_data: + self._can_write_to_account_data = True self._account_data_id_gen = StreamIdGenerator( db_conn, "room_account_data", @@ -90,8 +105,6 @@ class AccountDataWorkerStore(SQLBaseStore): "AccountDataAndTagsChangeCache", account_max ) - super().__init__(database, db_conn, hs) - def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream @@ -113,7 +126,9 @@ class AccountDataWorkerStore(SQLBaseStore): room_id string to per room account_data dicts. """ - def get_account_data_for_user_txn(txn): + def get_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: rows = self.db_pool.simple_select_list_txn( txn, "account_data", @@ -132,7 +147,7 @@ class AccountDataWorkerStore(SQLBaseStore): ["room_id", "account_data_type", "content"], ) - by_room = {} + by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) room_data[row["account_data_type"]] = db_to_json(row["content"]) @@ -177,7 +192,9 @@ class AccountDataWorkerStore(SQLBaseStore): A dict of the room account_data """ - def get_account_data_for_room_txn(txn): + def get_account_data_for_room_txn( + txn: LoggingTransaction, + ) -> Dict[str, JsonDict]: rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", @@ -207,7 +224,9 @@ class AccountDataWorkerStore(SQLBaseStore): The room account_data for that type, or None if there isn't any set. """ - def get_account_data_for_room_and_type_txn(txn): + def get_account_data_for_room_and_type_txn( + txn: LoggingTransaction, + ) -> Optional[JsonDict]: content_json = self.db_pool.simple_select_one_onecol_txn( txn, table="room_account_data", @@ -243,14 +262,16 @@ class AccountDataWorkerStore(SQLBaseStore): if last_id == current_id: return [] - def get_updated_global_account_data_txn(txn): + def get_updated_global_account_data_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str]]: sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + return cast(List[Tuple[int, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_global_account_data", get_updated_global_account_data_txn @@ -273,14 +294,16 @@ class AccountDataWorkerStore(SQLBaseStore): if last_id == current_id: return [] - def get_updated_room_account_data_txn(txn): + def get_updated_room_account_data_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str]]: sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + return cast(List[Tuple[int, str, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_room_account_data", get_updated_room_account_data_txn @@ -299,7 +322,9 @@ class AccountDataWorkerStore(SQLBaseStore): mapping from room_id string to per room account_data dicts. """ - def get_updated_account_data_for_user_txn(txn): + def get_updated_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: sql = ( "SELECT account_data_type, content FROM account_data" " WHERE user_id = ? AND stream_id > ?" @@ -316,7 +341,7 @@ class AccountDataWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) - account_data_by_room = {} + account_data_by_room: Dict[str, Dict[str, JsonDict]] = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) @@ -353,12 +378,15 @@ class AccountDataWorkerStore(SQLBaseStore): ) ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == TagAccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) - for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) for row in rows: @@ -372,7 +400,8 @@ class AccountDataWorkerStore(SQLBaseStore): (row.user_id, row.room_id, row.data_type) ) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + + super().process_replication_rows(stream_name, instance_name, token, rows) async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict @@ -389,6 +418,7 @@ class AccountDataWorkerStore(SQLBaseStore): The maximum stream ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -431,6 +461,7 @@ class AccountDataWorkerStore(SQLBaseStore): The maximum stream ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( @@ -452,7 +483,7 @@ class AccountDataWorkerStore(SQLBaseStore): def _add_account_data_for_user( self, - txn, + txn: LoggingTransaction, next_id: int, user_id: str, account_data_type: str, diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc166..92c95a41d7 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -24,9 +24,8 @@ from synapse.appservice import ( from synapse.config.appservice import load_appservices from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.types import Connection from synapse.types import JsonDict from synapse.util import json_encoder @@ -58,7 +57,12 @@ def _make_exclusive_regex( class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self.services_cache = load_appservices( hs.hostname, hs.config.appservice.app_service_config_files ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 36e8422fc6..0024348067 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py
@@ -25,7 +25,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, ) from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter @@ -41,7 +41,12 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class CacheInvalidationWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 0f56e10220..fd3fc298b3 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py
@@ -18,7 +18,11 @@ from typing import TYPE_CHECKING, Optional from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.util import json_encoder @@ -31,7 +35,12 @@ logger = logging.getLogger(__name__) class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if ( diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index a6fd9f2636..f3881671fd 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py
@@ -26,7 +26,6 @@ from synapse.storage.database import ( make_tuple_comparison_clause, ) from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore -from synapse.storage.types import Connection from synapse.types import JsonDict, UserID from synapse.util.caches.lrucache import LruCache @@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict): class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -394,7 +398,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.server.user_ips_max_age @@ -532,7 +541,12 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): # (user_id, access_token, ip,) -> last_seen self.client_ip_last_seen = LruCache[Tuple[str, str, str], int]( diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ab8766c75b..3682cb6a81 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox" REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -668,7 +673,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): # There's a type mismatch here between how we want to type the row and # what fetchone says it returns, but we silence it because we know that # res can't be None. - res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment] + res = cast(Tuple[Optional[int]], txn.fetchone()) if res[0] is None: # this can only happen if the `device_inbox` table is empty, in which # case we have no work to do. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d5a4a661cd..273adb61fd 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -38,6 +38,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_tuple_comparison_clause, ) @@ -61,7 +62,12 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -101,7 +107,9 @@ class DeviceWorkerStore(SQLBaseStore): "count_devices_by_users", count_devices_by_users_txn, user_ids ) - async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: + async def get_device( + self, user_id: str, device_id: str + ) -> Optional[Dict[str, Any]]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -109,15 +117,35 @@ class DeviceWorkerStore(SQLBaseStore): user_id: The ID of the user which owns the device device_id: The ID of the device to retrieve Returns: - A dict containing the device information - Raises: - StoreError: if the device is not found + A dict containing the device information, or `None` if the device does not + exist. """ return await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", + allow_none=True, + ) + + async def get_device_opt( + self, user_id: str, device_id: str + ) -> Optional[Dict[str, Any]]: + """Retrieve a device. Only returns devices that are not marked as + hidden. + + Args: + user_id: The ID of the user which owns the device + device_id: The ID of the device to retrieve + Returns: + A dict containing the device information, or None if the device does not exist. + """ + return await self.db_pool.simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + allow_none=True, ) async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: @@ -274,7 +302,9 @@ class DeviceWorkerStore(SQLBaseStore): # add the updated cross-signing keys to the results list for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("m.signing_key_update", result)) + # also send the unstable version + # FIXME: remove this when enough servers have upgraded results.append(("org.matrix.signing_key_update", result)) return now_stream_id, results @@ -949,7 +979,12 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -1081,7 +1116,12 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index a3442814d7..f76c6121e8 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py
@@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import Iterable, List, Optional, Tuple +import attr + from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached -RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomAliasMapping: + room_id: str + room_alias: str + servers: List[str] class DirectoryWorkerStore(CacheInvalidationWorkerStore): diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b15fb71e62..0cb48b9dd7 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,35 +13,71 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Dict, Iterable, Mapping, Optional, Tuple, cast + +from typing_extensions import Literal, TypedDict from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import LoggingTransaction +from synapse.types import JsonDict, JsonSerializable from synapse.util import json_encoder +class RoomKey(TypedDict): + """`KeyBackupData` in the Matrix spec. + + https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid + """ + + first_message_index: int + forwarded_count: int + is_verified: bool + session_data: JsonSerializable + + class EndToEndRoomKeyStore(SQLBaseStore): + """The store for end to end room key backups. + + See https://spec.matrix.org/v1.1/client-server-api/#server-side-key-backups + + As per the spec, backups are identified by an opaque version string. Internally, + version identifiers are assigned using incrementing integers. Non-numeric version + strings are treated as if they do not exist, since we would have never issued them. + """ + async def update_e2e_room_key( - self, user_id, version, room_id, session_id, room_key - ): + self, + user_id: str, + version: str, + room_id: str, + session_id: str, + room_key: RoomKey, + ) -> None: """Replaces the encrypted E2E room key for a given session in a given backup Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_id(str): the ID of the room whose keys we're setting - session_id(str): the session whose room_key we're setting - room_key(dict): the room_key being set + user_id: the user whose backup we're setting + version: the version ID of the backup we're updating + room_id: the ID of the room whose keys we're setting + session_id: the session whose room_key we're setting + room_key: the room_key being set Raises: StoreError """ + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + raise StoreError(404, "No backup with that version exists") await self.db_pool.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, - "version": version, + "version": version_int, "room_id": room_id, "session_id": session_id, }, @@ -54,22 +90,29 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="update_e2e_room_key", ) - async def add_e2e_room_keys(self, user_id, version, room_keys): + async def add_e2e_room_keys( + self, user_id: str, version: str, room_keys: Iterable[Tuple[str, str, RoomKey]] + ) -> None: """Bulk add room keys to a given backup. Args: - user_id (str): the user whose backup we're adding to - version (str): the version ID of the backup for the set of keys we're adding to - room_keys (iterable[(str, str, dict)]): the keys to add, in the form - (roomID, sessionID, keyData) + user_id: the user whose backup we're adding to + version: the version ID of the backup for the set of keys we're adding to + room_keys: the keys to add, in the form (roomID, sessionID, keyData) """ + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + raise StoreError(404, "No backup with that version exists") values = [] for (room_id, session_id, room_key) in room_keys: values.append( { "user_id": user_id, - "version": version, + "version": version_int, "room_id": room_id, "session_id": session_id, "first_message_index": room_key["first_message_index"], @@ -92,31 +135,39 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_e2e_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> Dict[ + Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]] + ]: """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup for the set of keys we're querying - room_id (str): Optional. the ID of the room whose keys we're querying, if any. + user_id: the user whose backup we're querying + version: the version ID of the backup for the set of keys we're querying + room_id: Optional. the ID of the room whose keys we're querying, if any. If not specified, we return the keys for all the rooms in the backup. - session_id (str): Optional. the session whose room_key we're querying, if any. + session_id: Optional. the session whose room_key we're querying, if any. If specified, we also require the room_id to be specified. If not specified, we return all the keys in this version of the backup (or for the specified room) Returns: - A list of dicts giving the session_data and message metadata for - these room keys. + A dict giving the session_data and message metadata for these room keys. + `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}` """ try: - version = int(version) + version_int = int(version) except ValueError: return {"rooms": {}} - keyvalues = {"user_id": user_id, "version": version} + keyvalues = {"user_id": user_id, "version": version_int} if room_id: keyvalues["room_id"] = room_id if session_id: @@ -137,7 +188,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="get_e2e_room_keys", ) - sessions = {"rooms": {}} + sessions: Dict[ + Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]] + ] = {"rooms": {}} for row in rows: room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) room_entry["sessions"][row["session_id"]] = { @@ -150,7 +203,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): return sessions - async def get_e2e_room_keys_multi(self, user_id, version, room_keys): + async def get_e2e_room_keys_multi( + self, + user_id: str, + version: str, + room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]], + ) -> Dict[str, Dict[str, RoomKey]]: """Get multiple room keys at a time. The difference between this function and get_e2e_room_keys is that this function can be used to retrieve multiple specific keys at a time, whereas get_e2e_room_keys is used for @@ -158,26 +216,36 @@ class EndToEndRoomKeyStore(SQLBaseStore): specific key. Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup we're querying about - room_keys (dict[str, dict[str, iterable[str]]]): a map from - room ID -> {"session": [session ids]} indicating the session IDs - that we want to query + user_id: the user whose backup we're querying + version: the version ID of the backup we're querying about + room_keys: a map from room ID -> {"sessions": [session ids]} + indicating the session IDs that we want to query Returns: - dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key + A map of room IDs to session IDs to room key """ + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + return {} return await self.db_pool.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, - version, + version_int, room_keys, ) @staticmethod - def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): + def _get_e2e_room_keys_multi_txn( + txn: LoggingTransaction, + user_id: str, + version: int, + room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]], + ) -> Dict[str, Dict[str, RoomKey]]: if not room_keys: return {} @@ -209,7 +277,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): txn.execute(sql, params) - ret = {} + ret: Dict[str, Dict[str, RoomKey]] = {} for row in txn: room_id = row[0] @@ -231,36 +299,49 @@ class EndToEndRoomKeyStore(SQLBaseStore): user_id: the user whose backup we're querying version: the version ID of the backup we're querying about """ + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + return 0 return await self.db_pool.simple_select_one_onecol( table="e2e_room_keys", - keyvalues={"user_id": user_id, "version": version}, + keyvalues={"user_id": user_id, "version": version_int}, retcol="COUNT(*)", desc="count_e2e_room_keys", ) @trace async def delete_e2e_room_keys( - self, user_id, version, room_id=None, session_id=None - ): + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> None: """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. Args: - user_id(str): the user whose backup we're deleting from - version(str): the version ID of the backup for the set of keys we're deleting - room_id(str): Optional. the ID of the room whose keys we're deleting, if any. + user_id: the user whose backup we're deleting from + version: the version ID of the backup for the set of keys we're deleting + room_id: Optional. the ID of the room whose keys we're deleting, if any. If not specified, we delete the keys for all the rooms in the backup. - session_id(str): Optional. the session whose room_key we're querying, if any. + session_id: Optional. the session whose room_key we're querying, if any. If specified, we also require the room_id to be specified. If not specified, we delete all the keys in this version of the backup (or for the specified room) - - Returns: - The deletion transaction """ + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + return - keyvalues = {"user_id": user_id, "version": int(version)} + keyvalues = {"user_id": user_id, "version": version_int} if room_id: keyvalues["room_id"] = room_id if session_id: @@ -271,23 +352,27 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @staticmethod - def _get_current_version(txn, user_id): + def _get_current_version(txn: LoggingTransaction, user_id: str) -> int: txn.execute( "SELECT MAX(version) FROM e2e_room_keys_versions " "WHERE user_id=? AND deleted=0", (user_id,), ) - row = txn.fetchone() - if not row: + # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will + # be `NULL` when there are no available versions. + row = cast(Tuple[Optional[int]], txn.fetchone()) + if row[0] is None: raise StoreError(404, "No current backup version") return row[0] - async def get_e2e_room_keys_version_info(self, user_id, version=None): + async def get_e2e_room_keys_version_info( + self, user_id: str, version: Optional[str] = None + ) -> JsonDict: """Get info metadata about a version of our room_keys backup. Args: - user_id(str): the user whose backup we're querying - version(str): Optional. the version ID of the backup we're querying about + user_id: the user whose backup we're querying + version: Optional. the version ID of the backup we're querying about If missing, we return the information about the current version. Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present @@ -300,7 +385,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): etag(int): tag of the keys in the backup """ - def _get_e2e_room_keys_version_info_txn(txn): + def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict: if version is None: this_version = self._get_current_version(txn, user_id) else: @@ -309,14 +394,16 @@ class EndToEndRoomKeyStore(SQLBaseStore): except ValueError: # Our versions are all ints so if we can't convert it to an integer, # it isn't there. - raise StoreError(404, "No row found") + raise StoreError(404, "No backup with that version exists") result = self.db_pool.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, retcols=("version", "algorithm", "auth_data", "etag"), + allow_none=False, ) + assert result is not None # see comment on `simple_select_one_txn` result["auth_data"] = db_to_json(result["auth_data"]) result["version"] = str(result["version"]) if result["etag"] is None: @@ -328,28 +415,28 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str: + async def create_e2e_room_keys_version(self, user_id: str, info: JsonDict) -> str: """Atomically creates a new version of this user's e2e_room_keys store with the given version info. Args: - user_id(str): the user whose backup we're creating a version - info(dict): the info about the backup version to be created + user_id: the user whose backup we're creating a version + info: the info about the backup version to be created Returns: The newly created version ID """ - def _create_e2e_room_keys_version_txn(txn): + def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str: txn.execute( "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", (user_id,), ) - current_version = txn.fetchone()[0] + current_version = cast(Tuple[Optional[int]], txn.fetchone())[0] if current_version is None: - current_version = "0" + current_version = 0 - new_version = str(int(current_version) + 1) + new_version = current_version + 1 self.db_pool.simple_insert_txn( txn, @@ -362,7 +449,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): }, ) - return new_version + return str(new_version) return await self.db_pool.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn @@ -373,7 +460,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): self, user_id: str, version: str, - info: Optional[dict] = None, + info: Optional[JsonDict] = None, version_etag: Optional[int] = None, ) -> None: """Update a given backup version @@ -386,7 +473,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version_etag: etag of the keys in the backup. If None, then the etag is not updated. """ - updatevalues = {} + updatevalues: Dict[str, object] = {} if info is not None and "auth_data" in info: updatevalues["auth_data"] = json_encoder.encode(info["auth_data"]) @@ -394,9 +481,16 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues["etag"] = version_etag if updatevalues: - await self.db_pool.simple_update( + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + raise StoreError(404, "No backup with that version exists") + + await self.db_pool.simple_update_one( table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": version}, + keyvalues={"user_id": user_id, "version": version_int}, updatevalues=updatevalues, desc="update_e2e_room_keys_version", ) @@ -417,13 +511,16 @@ class EndToEndRoomKeyStore(SQLBaseStore): or if the version requested doesn't exist. """ - def _delete_e2e_room_keys_version_txn(txn): + def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None: if version is None: this_version = self._get_current_version(txn, user_id) - if this_version is None: - raise StoreError(404, "No current backup version") else: - this_version = version + try: + this_version = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it isn't there. + raise StoreError(404, "No backup with that version exists") self.db_pool.simple_delete_txn( txn, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b06c1dc45b..57b5ffbad3 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,19 +14,32 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + cast, +) import attr from canonicaljson import encode_canonical_json -from twisted.enterprise.adbapi import Connection - from synapse.api.constants import DeviceKeyAlgorithms from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -50,7 +63,12 @@ class DeviceKeyLookupResult: class EndToEndKeyBackgroundStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -62,8 +80,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore): ) -class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): +class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._allow_device_name_lookup_over_federation = ( @@ -124,7 +147,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): # Build the result structure, un-jsonify the results, and add the # "unsigned" section - rv = {} + rv: Dict[str, Dict[str, JsonDict]] = {} for user_id, device_keys in results.items(): rv[user_id] = {} for device_id, device_info in device_keys.items(): @@ -195,6 +218,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): # add each cross-signing signature to the correct device in the result dict. for (user_id, key_id, device_id, signature) in cross_sigs_result: target_device_result = result[user_id][device_id] + # We've only looked up cross-signatures for non-deleted devices with key + # data. + assert target_device_result is not None + assert target_device_result.keys is not None target_device_signatures = target_device_result.keys.setdefault( "signatures", {} ) @@ -207,7 +234,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): return result def _get_e2e_device_keys_txn( - self, txn, query_list, include_all_devices=False, include_deleted_devices=False + self, + txn: LoggingTransaction, + query_list: Collection[Tuple[str, str]], + include_all_devices: bool = False, + include_deleted_devices: bool = False, ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: """Get information on devices from the database @@ -263,7 +294,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): return result def _get_e2e_cross_signing_signatures_for_devices_txn( - self, txn: Cursor, device_query: Iterable[Tuple[str, str]] + self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]] ) -> List[Tuple[str, str, str, str]]: """Get cross-signing signatures for a given list of devices @@ -289,7 +320,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): ) txn.execute(signature_sql, signature_query_params) - return txn.fetchall() + return cast( + List[ + Tuple[ + str, + str, + str, + str, + ] + ], + txn.fetchall(), + ) async def get_e2e_one_time_keys( self, user_id: str, device_id: str, key_ids: List[str] @@ -335,7 +376,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ - def _add_e2e_one_time_keys(txn): + def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None: set_tag("user_id", user_id) set_tag("device_id", device_id) set_tag("new_keys", new_keys) @@ -375,7 +416,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): A mapping from algorithm to number of keys for that algorithm. """ - def _count_e2e_one_time_keys(txn): + def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]: sql = ( "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ?" @@ -421,7 +462,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): ) def _set_e2e_fallback_keys_txn( - self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + fallback_keys: JsonDict, ) -> None: # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad @@ -483,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None - ) -> Optional[dict]: + ) -> Optional[JsonDict]: """Returns a user's cross-signing key. Args: @@ -504,7 +549,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): return user_keys.get(key_type) @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id): + def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]: """Dummy function. Only used to make a cache for _get_bare_e2e_cross_signing_keys_bulk. """ @@ -517,7 +562,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -531,32 +576,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): their user ID will map to None. """ - return await self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_bare_e2e_cross_signing_keys_bulk", self._get_bare_e2e_cross_signing_keys_bulk_txn, user_ids, ) + # The `Optional` comes from the `@cachedList` decorator. + return cast(Dict[str, Optional[Dict[str, JsonDict]]], result) + def _get_bare_e2e_cross_signing_keys_bulk_txn( self, - txn: Connection, + txn: LoggingTransaction, user_ids: Iterable[str], - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Dict[str, JsonDict]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_ids (list[str]): the users whose keys are being requested + txn: db connection + user_ids: the users whose keys are being requested Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. If a user's cross-signing keys were not found, their user - ID will not be in the dict. + Mapping from user ID to key type to key data. + If a user's cross-signing keys were not found, their user ID will not be in + the dict. """ - result = {} + result: Dict[str, Dict[str, JsonDict]] = {} for user_chunk in batch_iter(user_ids, 100): clause, params = make_in_list_sql_clause( @@ -596,43 +644,48 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): user_id = row["user_id"] key_type = row["keytype"] key = db_to_json(row["keydata"]) - user_info = result.setdefault(user_id, {}) - user_info[key_type] = key + user_keys = result.setdefault(user_id, {}) + user_keys[key_type] = key return result def _get_e2e_cross_signing_signatures_txn( self, - txn: Connection, - keys: Dict[str, Dict[str, dict]], + txn: LoggingTransaction, + keys: Dict[str, Optional[Dict[str, JsonDict]]], from_user_id: str, - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing signatures made by a user on a set of keys. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - keys (dict[str, dict[str, dict]]): a map of user ID to key type to - key data. This dict will be modified to add signatures. - from_user_id (str): fetch the signatures made by this user + txn: db connection + keys: a map of user ID to key type to key data. + This dict will be modified to add signatures. + from_user_id: fetch the signatures made by this user Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. The return value will be the same as the keys argument, - with the modifications included. + Mapping from user ID to key type to key data. + The return value will be the same as the keys argument, with the + modifications included. """ # find out what cross-signing keys (a.k.a. devices) we need to get # signatures for. This is a map of (user_id, device_id) to key type # (device_id is the key's public part). - devices = {} + devices: Dict[Tuple[str, str], str] = {} - for user_id, user_info in keys.items(): - if user_info is None: + for user_id, user_keys in keys.items(): + if user_keys is None: continue - for key_type, key in user_info.items(): + for key_type, key in user_keys.items(): device_id = None for k in key["keys"].values(): device_id = k + # `key` ought to be a `CrossSigningKey`, whose .keys property is a + # dictionary with a single entry: + # "algorithm:base64_public_key": "base64_public_key" + # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing + assert isinstance(device_id, str) devices[(user_id, device_id)] = key_type for batch in batch_iter(devices.keys(), size=100): @@ -656,15 +709,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): # and add the signatures to the appropriate keys for row in rows: - key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] + key_id: str = row["key_id"] + target_user_id: str = row["target_user_id"] + target_device_id: str = row["target_device_id"] key_type = devices[(target_user_id, target_device_id)] # We need to copy everything, because the result may have come # from the cache. dict.copy only does a shallow copy, so we # need to recursively copy the dicts that will be modified. - user_info = keys[target_user_id] = keys[target_user_id].copy() - target_user_key = user_info[key_type] = user_info[key_type].copy() + user_keys = keys[target_user_id] + # `user_keys` cannot be `None` because we only fetched signatures for + # users with keys + assert user_keys is not None + user_keys = keys[target_user_id] = user_keys.copy() + + target_user_key = user_keys[key_type] = user_keys[key_type].copy() if "signatures" in target_user_key: signatures = target_user_key["signatures"] = target_user_key[ "signatures" @@ -683,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Optional[Dict[str, dict]]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -741,7 +799,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): if last_id == current_id: return [], current_id, False - def _get_all_user_signature_changes_for_remotes_txn(txn): + def _get_all_user_signature_changes_for_remotes_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT stream_id, from_user_id AS user_id FROM user_signature_stream @@ -785,7 +845,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): @trace def _claim_e2e_one_time_key_simple( - txn, user_id: str, device_id: str, algorithm: str + txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str ) -> Optional[Tuple[str, str]]: """Claim OTK for device for DBs that don't support RETURNING. @@ -825,7 +885,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): @trace def _claim_e2e_one_time_key_returning( - txn, user_id: str, device_id: str, algorithm: str + txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str ) -> Optional[Tuple[str, str]]: """Claim OTK for device for DBs that support RETURNING. @@ -860,7 +920,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): key_id, key_json = otk_row return f"{algorithm}:{key_id}", key_json - results = {} + results: Dict[str, Dict[str, Dict[str, str]]] = {} for user_id, device_id, algorithm in query_list: if self.database_engine.supports_returning: # If we support RETURNING clause we can use a single query that @@ -930,6 +990,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._cross_signing_id_gen = StreamIdGenerator( + db_conn, "e2e_cross_signing_keys", "stream_id" + ) + async def set_e2e_device_keys( self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict ) -> bool: @@ -937,7 +1009,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): or the keys were already in the database. """ - def _set_e2e_device_keys_txn(txn): + def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool: set_tag("user_id", user_id) set_tag("device_id", device_id) set_tag("time_now", time_now) @@ -973,7 +1045,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: - def delete_e2e_keys_by_device_txn(txn): + def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: log_kv( { "message": "Deleting keys for device", @@ -1012,17 +1084,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id): + def _set_e2e_cross_signing_key_txn( + self, + txn: LoggingTransaction, + user_id: str, + key_type: str, + key: JsonDict, + stream_id: int, + ) -> None: """Set a user's cross-signing key. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_id (str): the user to set the signing key for - key_type (str): the type of key that is being set: either 'master' + txn: db connection + user_id: the user to set the signing key for + key_type: the type of key that is being set: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key - key (dict): the key data - stream_id (int) + key: the key data + stream_id """ # the 'key' dict will look something like: # { @@ -1075,13 +1154,15 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): txn, self._get_bare_e2e_cross_signing_keys, (user_id,) ) - async def set_e2e_cross_signing_key(self, user_id, key_type, key): + async def set_e2e_cross_signing_key( + self, user_id: str, key_type: str, key: JsonDict + ) -> None: """Set a user's cross-signing key. Args: - user_id (str): the user to set the user-signing key for - key_type (str): the type of cross-signing key to set - key (dict): the key data + user_id: the user to set the user-signing key for + key_type: the type of cross-signing key to set + key: the key data """ async with self._cross_signing_id_gen.get_next() as stream_id: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..270b30800b 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -24,7 +24,11 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.events import EventBase, make_event_from_dict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine @@ -62,7 +66,12 @@ class _NoChainCoverIndex(Exception): class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -279,7 +288,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas new_front = set() for chunk in batch_iter(front, 100): # Pull the auth events either from the cache or DB. - to_fetch = [] # Event IDs to fetch from DB # type: List[str] + to_fetch: List[str] = [] # Event IDs to fetch from DB for event_id in chunk: res = self._event_auth_cache.get(event_id) if res is None: @@ -606,8 +615,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # currently walking, either from cache or DB. search, chunk = search[:-100], search[-100:] - found = [] # Results found # type: List[Tuple[str, str, int]] - to_fetch = [] # Event IDs to fetch from DB # type: List[str] + found: List[Tuple[str, str, int]] = [] # Results found + to_fetch: List[str] = [] # Event IDs to fetch from DB for _, event_id in chunk: res = self._event_auth_cache.get(event_id) if res is None: @@ -1384,7 +1393,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas count = await self.db_pool.simple_select_one_onecol( table="federation_inbound_events_staging", keyvalues={"room_id": room_id}, - retcol="COALESCE(COUNT(*), 0)", + retcol="COUNT(*)", desc="prune_staged_events_in_room_count", ) @@ -1476,9 +1485,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """Update the prometheus metrics for the inbound federation staging area.""" def _get_stats_for_federation_staging_txn(txn): - txn.execute( - "SELECT coalesce(count(*), 0) FROM federation_inbound_events_staging" - ) + txn.execute("SELECT count(*) FROM federation_inbound_events_staging") (count,) = txn.fetchone() txn.execute( @@ -1514,7 +1521,12 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920..a98e6b2593 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,14 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast import attr -from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -30,29 +33,64 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] -DEFAULT_HIGHLIGHT_ACTION = [ +DEFAULT_NOTIF_ACTION: List[Union[dict, str]] = [ + "notify", + {"set_tweak": "highlight", "value": False}, +] +DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [ "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}, ] -class BasePushAction(TypedDict): - event_id: str - actions: List[Union[dict, str]] - +@attr.s(slots=True, frozen=True, auto_attribs=True) +class HttpPushAction: + """ + HttpPushAction instances include the information used to generate HTTP + requests to a push gateway. + """ -class HttpPushAction(BasePushAction): + event_id: str room_id: str stream_ordering: int + actions: List[Union[dict, str]] +@attr.s(slots=True, frozen=True, auto_attribs=True) class EmailPushAction(HttpPushAction): + """ + EmailPushAction instances include the information used to render an email + push notification. + """ + received_ts: Optional[int] -def _serialize_action(actions, is_highlight): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UserPushAction(EmailPushAction): + """ + UserPushAction instances include the necessary information to respond to + /notifications requests. + """ + + topological_ordering: int + highlight: bool + profile_tag: str + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class NotifCounts: + """ + The per-user, per-room count of notifications. Used by sync and push. + """ + + notify_count: int + unread_count: int + highlight_count: int + + +def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str: """Custom serializer for actions. This allows us to "compress" common actions. We use the fact that most users have the same actions for notifs (and for @@ -70,7 +108,7 @@ def _serialize_action(actions, is_highlight): return json_encoder.encode(actions) -def _deserialize_action(actions, is_highlight): +def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, str]]: """Custom deserializer for actions. This allows us to "compress" common actions""" if actions: return db_to_json(actions) @@ -82,12 +120,17 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn - self.stream_ordering_month_ago = None - self.stream_ordering_day_ago = None + self.stream_ordering_month_ago: Optional[int] = None + self.stream_ordering_day_ago: Optional[int] = None cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn") self._find_stream_orderings_for_times_txn(cur) @@ -111,7 +154,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): room_id: str, user_id: str, last_read_event_id: Optional[str], - ) -> Dict[str, int]: + ) -> NotifCounts: """Get the notification count, the highlight count and the unread message count for a given user in a given room after the given read receipt. @@ -140,15 +183,15 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _get_unread_counts_by_receipt_txn( self, - txn, - room_id, - user_id, - last_read_event_id, - ): + txn: LoggingTransaction, + room_id: str, + user_id: str, + last_read_event_id: Optional[str], + ) -> NotifCounts: stream_ordering = None if last_read_event_id is not None: - stream_ordering = self.get_stream_id_for_event_txn( + stream_ordering = self.get_stream_id_for_event_txn( # type: ignore[attr-defined] txn, last_read_event_id, allow_none=True, @@ -166,13 +209,15 @@ class EventPushActionsWorkerStore(SQLBaseStore): retcol="event_id", ) - stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) + stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) # type: ignore[attr-defined] return self._get_unread_counts_by_pos_txn( txn, room_id, user_id, stream_ordering ) - def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): + def _get_unread_counts_by_pos_txn( + self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int + ) -> NotifCounts: sql = ( "SELECT" " COUNT(CASE WHEN notif = 1 THEN 1 END)," @@ -210,16 +255,16 @@ class EventPushActionsWorkerStore(SQLBaseStore): # for this row. unread_count += row[1] - return { - "notify_count": notif_count, - "unread_count": unread_count, - "highlight_count": highlight_count, - } + return NotifCounts( + notify_count=notif_count, + unread_count=unread_count, + highlight_count=highlight_count, + ) async def get_push_action_users_in_range( - self, min_stream_ordering, max_stream_ordering - ): - def f(txn): + self, min_stream_ordering: int, max_stream_ordering: int + ) -> List[str]: + def f(txn: LoggingTransaction) -> List[str]: sql = ( "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" " stream_ordering >= ? AND stream_ordering <= ? AND notif = 1" @@ -227,8 +272,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f) - return ret + return await self.db_pool.runInteraction("get_push_action_users_in_range", f) async def get_unread_push_actions_for_user_in_range_for_http( self, @@ -254,7 +298,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ # find rooms that have a read receipt in them and return the next # push actions - def get_after_receipt(txn): + def get_after_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool]]: # find rooms that have a read receipt in them and return the next # push actions sql = ( @@ -280,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt @@ -289,7 +335,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): # There are rooms with push actions in them but you don't have a read receipt in # them e.g. rooms you've been invited to, so get push actions for rooms which do # not have read receipts in them too. - def get_no_receipt(txn): + def get_no_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool]]: sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," " ep.highlight " @@ -309,19 +357,19 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) notifs = [ - { - "event_id": row[0], - "room_id": row[1], - "stream_ordering": row[2], - "actions": _deserialize_action(row[3], row[4]), - } + HttpPushAction( + event_id=row[0], + room_id=row[1], + stream_ordering=row[2], + actions=_deserialize_action(row[3], row[4]), + ) for row in after_read_receipt + no_read_receipt ] @@ -329,7 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by stream_ordering, oldest first. - notifs.sort(key=lambda r: r["stream_ordering"]) + notifs.sort(key=lambda r: r.stream_ordering) # Take only up to the limit. We have to stop at the limit because # one of the subqueries may have hit the limit. @@ -359,7 +407,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ # find rooms that have a read receipt in them and return the most recent # push actions - def get_after_receipt(txn): + def get_after_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool, int]]: sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," " ep.highlight, e.received_ts" @@ -384,7 +434,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt @@ -393,7 +443,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): # There are rooms with push actions in them but you don't have a read receipt in # them e.g. rooms you've been invited to, so get push actions for rooms which do # not have read receipts in them too. - def get_no_receipt(txn): + def get_no_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool, int]]: sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," " ep.highlight, e.received_ts" @@ -413,7 +465,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt @@ -421,13 +473,13 @@ class EventPushActionsWorkerStore(SQLBaseStore): # Make a list of dicts from the two sets of results. notifs = [ - { - "event_id": row[0], - "room_id": row[1], - "stream_ordering": row[2], - "actions": _deserialize_action(row[3], row[4]), - "received_ts": row[5], - } + EmailPushAction( + event_id=row[0], + room_id=row[1], + stream_ordering=row[2], + actions=_deserialize_action(row[3], row[4]), + received_ts=row[5], + ) for row in after_read_receipt + no_read_receipt ] @@ -435,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by received_ts (most recent first) - notifs.sort(key=lambda r: -(r["received_ts"] or 0)) + notifs.sort(key=lambda r: -(r.received_ts or 0)) # Now return the first `limit` return notifs[:limit] @@ -456,7 +508,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): True if there may be push to process, False if there definitely isn't. """ - def _get_if_maybe_push_in_range_for_user_txn(txn): + def _get_if_maybe_push_in_range_for_user_txn(txn: LoggingTransaction) -> bool: sql = """ SELECT 1 FROM event_push_actions WHERE user_id = ? AND stream_ordering > ? AND notif = 1 @@ -490,19 +542,21 @@ class EventPushActionsWorkerStore(SQLBaseStore): # This is a helper function for generating the necessary tuple that # can be used to insert into the `event_push_actions_staging` table. - def _gen_entry(user_id, actions): + def _gen_entry( + user_id: str, actions: List[Union[dict, str]] + ) -> Tuple[str, str, str, int, int, int]: is_highlight = 1 if _action_has_highlight(actions) else 0 notif = 1 if "notify" in actions else 0 return ( event_id, # event_id column user_id, # user_id column - _serialize_action(actions, is_highlight), # actions column + _serialize_action(actions, bool(is_highlight)), # actions column notif, # notif column is_highlight, # highlight column int(count_as_unread), # unread column ) - def _add_push_actions_to_staging_txn(txn): + def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None: # We don't use simple_insert_many here to avoid the overhead # of generating lists of dicts. @@ -530,12 +584,11 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ try: - res = await self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", ) - return res except Exception: # this method is called from an exception handler, so propagating # another exception here really isn't helpful - there's nothing @@ -588,7 +641,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) @staticmethod - def _find_first_stream_ordering_after_ts_txn(txn, ts): + def _find_first_stream_ordering_after_ts_txn( + txn: LoggingTransaction, ts: int + ) -> int: """ Find the stream_ordering of the first event that was received on or after a given timestamp. This is relatively slow as there is no index @@ -600,14 +655,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): stream_ordering Args: - txn (twisted.enterprise.adbapi.Transaction): - ts (int): timestamp to search for + txn: + ts: timestamp to search for Returns: - int: stream ordering + The stream ordering """ txn.execute("SELECT MAX(stream_ordering) FROM events") - max_stream_ordering = txn.fetchone()[0] + max_stream_ordering = cast(Tuple[Optional[int]], txn.fetchone())[0] if max_stream_ordering is None: return 0 @@ -663,8 +718,10 @@ class EventPushActionsWorkerStore(SQLBaseStore): return range_end - async def get_time_of_last_push_action_before(self, stream_ordering): - def f(txn): + async def get_time_of_last_push_action_before( + self, stream_ordering: int + ) -> Optional[int]: + def f(txn: LoggingTransaction) -> Optional[Tuple[int]]: sql = ( "SELECT e.received_ts" " FROM event_push_actions AS ep" @@ -674,7 +731,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " LIMIT 1" ) txn.execute(sql, (stream_ordering,)) - return txn.fetchone() + return cast(Optional[Tuple[int]], txn.fetchone()) result = await self.db_pool.runInteraction( "get_time_of_last_push_action_before", f @@ -682,7 +739,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): return result[0] if result else None @wrap_as_background_process("rotate_notifs") - async def _rotate_notifs(self): + async def _rotate_notifs(self) -> None: if self._doing_notif_rotation or self.stream_ordering_day_ago is None: return self._doing_notif_rotation = True @@ -700,7 +757,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): finally: self._doing_notif_rotation = False - def _rotate_notifs_txn(self, txn): + def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool: """Archives older notifications into event_push_summary. Returns whether the archiving process has caught up or not. """ @@ -725,6 +782,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): stream_row = txn.fetchone() if stream_row: (offset_stream_ordering,) = stream_row + assert self.stream_ordering_day_ago is not None rotate_to_stream_ordering = min( self.stream_ordering_day_ago, offset_stream_ordering ) @@ -740,7 +798,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): # We have caught up iff we were limited by `stream_ordering_day_ago` return caught_up - def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): + def _rotate_notifs_before_txn( + self, txn: LoggingTransaction, rotate_to_stream_ordering: int + ) -> None: old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", @@ -861,8 +921,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _remove_old_push_actions_before_txn( - self, txn, room_id, user_id, stream_ordering - ): + self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int + ) -> None: """ Purges old push actions for a user and room before a given stream_ordering. @@ -910,7 +970,12 @@ class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -929,9 +994,15 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) async def get_push_actions_for_user( - self, user_id, before=None, limit=50, only_highlight=False - ): - def f(txn): + self, + user_id: str, + before: Optional[str] = None, + limit: int = 50, + only_highlight: bool = False, + ) -> List[UserPushAction]: + def f( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, int, str, bool, str, int]]: before_clause = "" if before: before_clause = "AND epa.stream_ordering < ?" @@ -958,32 +1029,44 @@ class EventPushActionsStore(EventPushActionsWorkerStore): " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) - return self.db_pool.cursor_to_dict(txn) + return cast( + List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall() + ) push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) - for pa in push_actions: - pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) - return push_actions + return [ + UserPushAction( + event_id=row[0], + room_id=row[1], + stream_ordering=row[2], + actions=_deserialize_action(row[4], row[5]), + received_ts=row[7], + topological_ordering=row[3], + highlight=row[5], + profile_tag=row[6], + ) + for row in push_actions + ] -def _action_has_highlight(actions): +def _action_has_highlight(actions: List[Union[dict, str]]) -> bool: for action in actions: - try: - if action.get("set_tweak", None) == "highlight": - return action.get("value", True) - except AttributeError: - pass + if not isinstance(action, dict): + continue + + if action.get("set_tweak", None) == "highlight": + return action.get("value", True) return False -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _EventPushSummary: """Summary of pending event push actions for a given user in a given room. Used in _rotate_notifs_before_txn to manipulate results from event_push_actions. """ - unread_count = attr.ib(type=int) - stream_ordering = attr.ib(type=int) - old_user_id = attr.ib(type=str) - notif_count = attr.ib(type=int) + unread_count: int + stream_ordering: int + old_user_id: str + notif_count: int diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..dd255aefb9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -19,6 +19,7 @@ from collections import OrderedDict from typing import ( TYPE_CHECKING, Any, + Collection, Dict, Generator, Iterable, @@ -40,10 +41,13 @@ from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry -from synapse.storage.types import Connection from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id @@ -94,7 +98,7 @@ class PersistEventsStore: hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore", - db_conn: Connection, + db_conn: LoggingDatabaseConnection, ): self.hs = hs self.db_pool = db @@ -1319,14 +1323,13 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] - def _store_event_txn(self, txn, events_and_contexts): + def _store_event_txn( + self, + txn: LoggingTransaction, + events_and_contexts: Collection[Tuple[EventBase, EventContext]], + ) -> None: """Insert new events into the event, event_json, redaction and state_events tables. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting """ if not events_and_contexts: @@ -1339,46 +1342,58 @@ class PersistEventsStore: d.pop("redacted_because", None) return d - self.db_pool.simple_insert_many_txn( + self.db_pool.simple_insert_many_values_txn( txn, table="event_json", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "internal_metadata": json_encoder.encode( - event.internal_metadata.get_dict() - ), - "json": json_encoder.encode(event_dict(event)), - "format_version": event.format_version, - } + keys=("event_id", "room_id", "internal_metadata", "json", "format_version"), + values=( + ( + event.event_id, + event.room_id, + json_encoder.encode(event.internal_metadata.get_dict()), + json_encoder.encode(event_dict(event)), + event.format_version, + ) for event, _ in events_and_contexts - ], + ), ) - self.db_pool.simple_insert_many_txn( + self.db_pool.simple_insert_many_values_txn( txn, table="events", - values=[ - { - "instance_name": self._instance_name, - "stream_ordering": event.internal_metadata.stream_ordering, - "topological_ordering": event.depth, - "depth": event.depth, - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "processed": True, - "outlier": event.internal_metadata.is_outlier(), - "origin_server_ts": int(event.origin_server_ts), - "received_ts": self._clock.time_msec(), - "sender": event.sender, - "contains_url": ( - "url" in event.content and isinstance(event.content["url"], str) - ), - } + keys=( + "instance_name", + "stream_ordering", + "topological_ordering", + "depth", + "event_id", + "room_id", + "type", + "processed", + "outlier", + "origin_server_ts", + "received_ts", + "sender", + "contains_url", + ), + values=( + ( + self._instance_name, + event.internal_metadata.stream_ordering, + event.depth, # topological_ordering + event.depth, # depth + event.event_id, + event.room_id, + event.type, + True, # processed + event.internal_metadata.is_outlier(), + int(event.origin_server_ts), + self._clock.time_msec(), + event.sender, + "url" in event.content and isinstance(event.content["url"], str), + ) for event, _ in events_and_contexts - ], + ), ) # If we're persisting an unredacted event we go and ensure @@ -1397,27 +1412,15 @@ class PersistEventsStore: ) txn.execute(sql + clause, [False] + args) - state_events_and_contexts = [ - ec for ec in events_and_contexts if ec[0].is_state() - ] - - state_values = [] - for event, _ in state_events_and_contexts: - vals = { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - } - - # TODO: How does this work with backfilling? - if hasattr(event, "replaces_state"): - vals["prev_state"] = event.replaces_state - - state_values.append(vals) - - self.db_pool.simple_insert_many_txn( - txn, table="state_events", values=state_values + self.db_pool.simple_insert_many_values_txn( + txn, + table="state_events", + keys=("event_id", "room_id", "type", "state_key"), + values=( + (event.event_id, event.room_id, event.type, event.state_key) + for event, _ in events_and_contexts + if event.is_state() + ), ) def _store_rejected_events_txn(self, txn, events_and_contexts): @@ -1780,10 +1783,14 @@ class PersistEventsStore: ) if rel_type == RelationTypes.REPLACE: - txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + txn.call_after( + self.store.get_applicable_edit.invalidate, (parent_id, event.room_id) + ) if rel_type == RelationTypes.THREAD: - txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + txn.call_after( + self.store.get_thread_summary.invalidate, (parent_id, event.room_id) + ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. @@ -1969,14 +1976,17 @@ class PersistEventsStore: txn, self.store.get_retention_policy_for_room, (event.room_id,) ) - def store_event_search_txn(self, txn, event, key, value): + def store_event_search_txn( + self, txn: LoggingTransaction, event: EventBase, key: str, value: str + ) -> None: """Add event to the search table Args: - txn (cursor): - event (EventBase): - key (str): - value (str): + txn: The database transaction. + event: The event being added to the search table. + key: A key describing the search value (one of "content.name", + "content.topic", or "content.body") + value: The value from the event's content. """ self.store.store_search_entries_txn( txn, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c88fd35e7f..a68f14ba48 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast import attr @@ -23,6 +23,7 @@ from synapse.events import make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_tuple_comparison_clause, ) @@ -83,7 +84,12 @@ class _CalculateChainCover: class EventsBackgroundUpdatesStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( @@ -234,12 +240,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ################################################################################ - async def _background_reindex_fields_sender(self, progress, batch_size): + async def _background_reindex_fields_sender( + self, progress: JsonDict, batch_size: int + ) -> int: target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - def reindex_txn(txn): + def reindex_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT stream_ordering, event_id, json FROM events" " INNER JOIN event_json USING (event_id)" @@ -301,12 +309,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return result - async def _background_reindex_origin_server_ts(self, progress, batch_size): + async def _background_reindex_origin_server_ts( + self, progress: JsonDict, batch_size: int + ) -> int: target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - def reindex_search_txn(txn): + def reindex_search_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT stream_ordering, event_id FROM events" " WHERE ? <= stream_ordering AND stream_ordering < ?" @@ -375,7 +385,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return result - async def _cleanup_extremities_bg_update(self, progress, batch_size): + async def _cleanup_extremities_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to clean out extremities that should have been deleted previously. @@ -396,12 +408,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): # have any descendants, but if they do then we should delete those # extremities. - def _cleanup_extremities_bg_update_txn(txn): + def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int: # The set of extremity event IDs that we're checking this round original_set = set() - # A dict[str, set[str]] of event ID to their prev events. - graph = {} + # A dict[str, Set[str]] of event ID to their prev events. + graph: Dict[str, Set[str]] = {} # The set of descendants of the original set that are not rejected # nor soft-failed. Ancestors of these events should be removed @@ -530,7 +542,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): room_ids = {row["room_id"] for row in rows} for room_id in room_ids: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) + self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined] ) self.db_pool.simple_delete_many_txn( @@ -552,7 +564,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES ) - def _drop_table_txn(txn): + def _drop_table_txn(txn: LoggingTransaction) -> None: txn.execute("DROP TABLE _extremities_to_check") await self.db_pool.runInteraction( @@ -561,11 +573,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return num_handled - async def _redactions_received_ts(self, progress, batch_size): + async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int: """Handles filling out the `received_ts` column in redactions.""" last_event_id = progress.get("last_event_id", "") - def _redactions_received_ts_txn(txn): + def _redactions_received_ts_txn(txn: LoggingTransaction) -> int: # Fetch the set of event IDs that we want to update sql = """ SELECT event_id FROM redactions @@ -616,10 +628,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return count - async def _event_fix_redactions_bytes(self, progress, batch_size): + async def _event_fix_redactions_bytes( + self, progress: JsonDict, batch_size: int + ) -> int: """Undoes hex encoded censored redacted event JSON.""" - def _event_fix_redactions_bytes_txn(txn): + def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None: # This update is quite fast due to new index. txn.execute( """ @@ -644,11 +658,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return 1 - async def _event_store_labels(self, progress, batch_size): + async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int: """Background update handler which will store labels for existing events.""" last_event_id = progress.get("last_event_id", "") - def _event_store_labels_txn(txn): + def _event_store_labels_txn(txn: LoggingTransaction) -> int: txn.execute( """ SELECT event_id, json FROM event_json @@ -748,7 +762,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ), ) - return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore + return cast( + List[Tuple[str, str, JsonDict, bool, bool]], + [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn], + ) results = await self.db_pool.runInteraction( desc="_rejected_events_metadata_get", func=get_rejected_events @@ -906,7 +923,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): def _calculate_chain_cover_txn( self, - txn: Cursor, + txn: LoggingTransaction, last_room_id: str, last_depth: int, last_stream: int, @@ -1017,10 +1034,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): PersistEventsStore._add_chain_cover_index( txn, self.db_pool, - self.event_chain_id_gen, + self.event_chain_id_gen, # type: ignore[attr-defined] event_to_room_id, event_to_types, - event_to_auth_chain, + cast(Dict[str, Sequence[str]], event_to_auth_chain), ) return _CalculateChainCover( @@ -1040,7 +1057,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): """ current_event_id = progress.get("current_event_id", "") - def purged_chain_cover_txn(txn) -> int: + def purged_chain_cover_txn(txn: LoggingTransaction) -> int: # The event ID from events will be null if the chain ID / sequence # number points to a purged event. sql = """ @@ -1175,14 +1192,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): # Iterate the parent IDs and invalidate caches. for parent_id in {r[1] for r in relations_to_insert}: cache_tuple = (parent_id,) - self._invalidate_cache_and_stream( - txn, self.get_relations_for_event, cache_tuple + self._invalidate_cache_and_stream( # type: ignore[attr-defined] + txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined] ) - self._invalidate_cache_and_stream( - txn, self.get_aggregation_groups_for_event, cache_tuple + self._invalidate_cache_and_stream( # type: ignore[attr-defined] + txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined] ) - self._invalidate_cache_and_stream( - txn, self.get_thread_summary, cache_tuple + self._invalidate_cache_and_stream( # type: ignore[attr-defined] + txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined] ) if results: @@ -1214,7 +1231,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): """ batch_size = max(batch_size, 1) - def process(txn: Cursor) -> int: + def process(txn: LoggingTransaction) -> int: last_stream = progress.get("last_stream", -(1 << 31)) txn.execute( """ diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c7b660ac5a..8d4287045a 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -1383,10 +1383,6 @@ class EventsWorkerStore(SQLBaseStore): return {"v1": complexity_v1} - def get_current_events_token(self) -> int: - """The current maximum token that events have reached""" - return self._stream_id_gen.get_current_token() - async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index cf842803bc..cb9ee08fa8 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Optional, Tuple, Union, cast from canonicaljson import encode_canonical_json @@ -63,7 +63,7 @@ class FilteringStore(SQLBaseStore): sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" txn.execute(sql, (user_localpart,)) - max_id = txn.fetchone()[0] # type: ignore[index] + max_id = cast(Tuple[Optional[int]], txn.fetchone())[0] if max_id is None: filter_id = 0 else: diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index bb621df0dd..3f6086050b 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py
@@ -19,8 +19,7 @@ from typing_extensions import TypedDict from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool -from synapse.storage.types import Connection +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.types import JsonDict from synapse.util import json_encoder @@ -40,7 +39,12 @@ class _RoomInGroup(TypedDict): class GroupServerWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): database.updates.register_background_index_update( update_name="local_group_updates_index", index_name="local_group_updates_stream_id_index", diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index a540f7fb26..bedacaf0d7 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py
@@ -20,8 +20,11 @@ from twisted.internet.interfaces import IReactorCore from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingTransaction -from synapse.storage.types import Connection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.util import Clock from synapse.util.stringutils import random_string @@ -54,7 +57,12 @@ class LockStore(SQLBaseStore): `last_renewed_ts` column with the current time. """ - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._reactor = hs.get_reactor() diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 1b076683f7..cbba356b4a 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
@@ -23,6 +23,7 @@ from typing import ( Optional, Tuple, Union, + cast, ) from synapse.storage._base import SQLBaseStore @@ -220,7 +221,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): WHERE user_id = ? """ txn.execute(sql, args) - count = txn.fetchone()[0] # type: ignore[index] + count = cast(Tuple[int], txn.fetchone())[0] sql = """ SELECT diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index d901933ae4..1480a0f048 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Dict from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) @@ -55,7 +55,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): stats and prometheus metrics. """ - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Read the extrems every 60 minutes @@ -100,7 +105,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def _count_messages(txn): sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events + SELECT COUNT(*) FROM events WHERE type = 'm.room.encrypted' AND stream_ordering > ? """ @@ -117,7 +122,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): like_clause = "%:" + self.hs.hostname sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events + SELECT COUNT(*) FROM events WHERE type = 'm.room.encrypted' AND sender LIKE ? AND stream_ordering > ? @@ -134,7 +139,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): async def count_daily_active_e2ee_rooms(self): def _count(txn): sql = """ - SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + SELECT COUNT(DISTINCT room_id) FROM events WHERE type = 'm.room.encrypted' AND stream_ordering > ? """ @@ -156,7 +161,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def _count_messages(txn): sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events + SELECT COUNT(*) FROM events WHERE type = 'm.room.message' AND stream_ordering > ? """ @@ -173,7 +178,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): like_clause = "%:" + self.hs.hostname sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events + SELECT COUNT(*) FROM events WHERE type = 'm.room.message' AND sender LIKE ? AND stream_ordering > ? @@ -190,7 +195,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): async def count_daily_active_rooms(self): def _count(txn): sql = """ - SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + SELECT COUNT(DISTINCT room_id) FROM events WHERE type = 'm.room.message' AND stream_ordering > ? """ @@ -226,7 +231,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): Returns number of users seen in the past time_from period """ sql = """ - SELECT COALESCE(count(*), 0) FROM ( + SELECT COUNT(*) FROM ( SELECT user_id FROM user_ips WHERE last_seen > ? GROUP BY user_id @@ -253,7 +258,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): thirty_days_ago_in_secs = now - thirty_days_in_secs sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT platform, COUNT(*) FROM ( SELECT users.name, platform, users.creation_ts * 1000, MAX(uip.last_seen) @@ -291,7 +296,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): results[row[0]] = row[1] sql = """ - SELECT COALESCE(count(*), 0) FROM ( + SELECT COUNT(*) FROM ( SELECT users.name, users.creation_ts * 1000, MAX(uip.last_seen) FROM users diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index b5284e4f67..8f09dd8e87 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -16,8 +16,13 @@ from typing import TYPE_CHECKING, Dict, List, Optional from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + make_in_list_sql_clause, +) from synapse.util.caches.descriptors import cached +from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -30,7 +35,12 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs @@ -49,7 +59,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): def _count_users(txn): # Exclude app service users sql = """ - SELECT COALESCE(count(*), 0) + SELECT COUNT(*) FROM monthly_active_users LEFT JOIN users ON monthly_active_users.user_id=users.name @@ -76,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): def _count_users_by_service(txn): sql = """ - SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0) + SELECT COALESCE(appservice_id, 'native'), COUNT(*) FROM monthly_active_users LEFT JOIN users ON monthly_active_users.user_id=users.name GROUP BY appservice_id; @@ -103,7 +113,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): : self.hs.config.server.max_mau_value ]: user_id = await self.hs.get_datastore().get_user_id_by_threepid( - tp["medium"], tp["address"] + tp["medium"], canonicalise_email(tp["address"]) ) if user_id: users.append(user_id) @@ -212,7 +222,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._mau_stats_only = hs.config.server.mau_stats_only diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index cc0eebdb46..cbf9ec38f7 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator @@ -33,7 +33,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) @@ -52,7 +52,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) @@ -269,6 +269,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): """ # Add user entries to the table, updating the presence_stream_id column if the user already # exists in the table. + presence_stream_id = self._presence_id_gen.get_current_token() await self.db_pool.simple_upsert_many( table="users_to_send_full_presence_to", key_names=("user_id",), @@ -279,9 +280,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): # devices at different times, each device will receive full presence once - when # the presence stream ID in their sync token is less than the one in the table # for their user ID. - value_values=( - (self._presence_id_gen.get_current_token(),) for _ in user_ids - ), + value_values=[(presence_stream_id,) for _ in user_ids], desc="add_users_to_send_full_presence_to", ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395..e01c94930a 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError, StoreError from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore @@ -81,7 +81,12 @@ class PushRulesWorkerStore( `get_max_push_rules_stream_id` which can be called in the initializer. """ - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b73ce53c91..747b4f31df 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -22,7 +22,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -196,27 +196,6 @@ class PusherWorkerStore(SQLBaseStore): # This only exists for the cachedList decorator raise NotImplementedError() - @cachedList( - cached_method_name="get_if_user_has_pusher", - list_name="user_ids", - num_args=1, - ) - async def get_if_users_have_pushers( - self, user_ids: Iterable[str] - ) -> Dict[str, bool]: - rows = await self.db_pool.simple_select_many_batch( - table="pushers", - column="user_name", - iterable=user_ids, - retcols=["user_name"], - desc="get_if_users_have_pushers", - ) - - result = {user_id: False for user_id in user_ids} - result.update({r["user_name"]: True for r in rows}) - - return result - async def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering ) -> None: @@ -515,7 +494,7 @@ class PusherStore(PusherWorkerStore): # invalidate, since we the user might not have had a pusher before await self.db_pool.runInteraction( "add_pusher", - self._invalidate_cache_and_stream, # type: ignore + self._invalidate_cache_and_stream, # type: ignore[attr-defined] self.get_if_user_has_pusher, (user_id,), ) @@ -524,7 +503,7 @@ class PusherStore(PusherWorkerStore): self, app_id: str, pushkey: str, user_id: str ) -> None: def delete_pusher_txn(txn, stream_id): - self._invalidate_cache_and_stream( # type: ignore + self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_if_user_has_pusher, (user_id,) ) @@ -569,7 +548,7 @@ class PusherStore(PusherWorkerStore): pushers = list(await self.get_pushers_by_user_id(user_id)) def delete_pushers_txn(txn, stream_ids): - self._invalidate_cache_and_stream( # type: ignore + self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_if_user_has_pusher, (user_id,) ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c99f8aebdb..bf0b903af2 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -14,14 +14,29 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, +) from twisted.internet import defer +from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict @@ -36,7 +51,12 @@ logger = logging.getLogger(__name__) class ReceiptsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self._instance_name = hs.get_instance_name() if isinstance(database.engine, PostgresEngine): @@ -78,17 +98,13 @@ class ReceiptsWorkerStore(SQLBaseStore): "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) - def get_max_receipt_stream_id(self): - """Get the current max stream ID for receipts stream - - Returns: - int - """ + def get_max_receipt_stream_id(self) -> int: + """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @cached() - async def get_users_with_read_receipts_in_room(self, room_id): - receipts = await self.get_receipts_for_room(room_id, "m.read") + async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]: + receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ) return {r["user_id"] for r in receipts} @cached(num_args=2) @@ -119,7 +135,9 @@ class ReceiptsWorkerStore(SQLBaseStore): ) @cached(num_args=2) - async def get_receipts_for_user(self, user_id, receipt_type): + async def get_receipts_for_user( + self, user_id: str, receipt_type: str + ) -> Dict[str, str]: rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, @@ -129,8 +147,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return {row["room_id"]: row["event_id"] for row in rows} - async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): - def f(txn): + async def get_receipts_for_user_with_orderings( + self, user_id: str, receipt_type: str + ) -> JsonDict: + def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]: sql = ( "SELECT rl.room_id, rl.event_id," " e.topological_ordering, e.stream_ordering" @@ -209,10 +229,10 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3, tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> List[JsonDict]: """See get_linearized_receipts_for_room""" - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = ( "SELECT * FROM receipts_linearized WHERE" @@ -250,11 +270,13 @@ class ReceiptsWorkerStore(SQLBaseStore): list_name="room_ids", num_args=3, ) - async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def _get_linearized_receipts_for_rooms( + self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None + ) -> Dict[str, List[JsonDict]]: if not room_ids: return {} - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -323,7 +345,7 @@ class ReceiptsWorkerStore(SQLBaseStore): A dictionary of roomids to a list of receipts. """ - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -379,7 +401,7 @@ class ReceiptsWorkerStore(SQLBaseStore): if last_id == current_id: return defer.succeed([]) - def _get_users_sent_receipts_between_txn(txn): + def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT user_id FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? @@ -419,7 +441,9 @@ class ReceiptsWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_receipts_txn(txn): + def get_all_updated_receipts_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized @@ -446,8 +470,8 @@ class ReceiptsWorkerStore(SQLBaseStore): def _invalidate_get_users_with_receipts_in_room( self, room_id: str, receipt_type: str, user_id: str - ): - if receipt_type != "m.read": + ) -> None: + if receipt_type != ReceiptTypes.READ: return res = self.get_users_with_read_receipts_in_room.cache.get_immediate( @@ -461,7 +485,9 @@ class ReceiptsWorkerStore(SQLBaseStore): self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + def invalidate_caches_for_receipt( + self, room_id: str, receipt_type: str, user_id: str + ) -> None: self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( @@ -482,11 +508,18 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_id, data, stream_id - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + data: JsonDict, + stream_id: int, + ) -> Optional[int]: """Inserts a read-receipt into the database if it's newer than the current RR - Returns: int|None + Returns: None if the RR is older than the current RR otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) @@ -550,7 +583,7 @@ class ReceiptsWorkerStore(SQLBaseStore): lock=False, ) - if receipt_type == "m.read" and stream_ordering is not None: + if receipt_type == ReceiptTypes.READ and stream_ordering is not None: self._remove_old_push_actions_before_txn( txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) @@ -580,7 +613,7 @@ class ReceiptsWorkerStore(SQLBaseStore): else: # we need to points in graph -> linearized form. # TODO: Make this better. - def graph_to_linear(txn): + def graph_to_linear(txn: LoggingTransaction) -> str: clause, args = make_in_list_sql_clause( self.database_engine, "event_id", event_ids ) @@ -634,11 +667,16 @@ class ReceiptsWorkerStore(SQLBaseStore): return stream_id, max_persisted_id async def insert_graph_receipt( - self, room_id, receipt_type, user_id, event_ids, data - ): + self, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -649,8 +687,14 @@ class ReceiptsWorkerStore(SQLBaseStore): ) def insert_graph_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_ids, data - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e1ddf06916..4175c82a25 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@ import logging import random import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr @@ -794,7 +794,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): yesterday = int(self._clock.time()) - (60 * 60 * 24) sql = """ - SELECT user_type, COALESCE(count(*), 0) AS count FROM ( + SELECT user_type, COUNT(*) AS count FROM ( SELECT CASE WHEN is_guest=0 AND appservice_id IS NULL THEN 'native' @@ -819,7 +819,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): def _count_users(txn): txn.execute( """ - SELECT COALESCE(COUNT(*), 0) FROM users + SELECT COUNT(*) FROM users WHERE appservice_id IS NULL """ ) @@ -856,7 +856,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Args: medium: threepid medium e.g. email - address: threepid address e.g. me@example.com + address: threepid address e.g. me@example.com. This must already be + in canonical form. Returns: The user ID or None if no user id/threepid mapping exists @@ -1356,12 +1357,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. - res: Dict[str, Any] = self.db_pool.simple_select_one_txn( - txn, - "registration_tokens", - keyvalues={"token": token}, - retcols=["pending", "completed"], - ) # type: ignore + res = cast( + Dict[str, Any], + self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ), + ) # Decrement pending and increment completed self.db_pool.simple_update_one_txn( diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0a43acda07..4ff6aed253 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, cast import attr @@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_relations_for_event( self, event_id: str, + room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, aggregation_key: Optional[str] = None, @@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. aggregation_key: Only fetch events with this aggregation key, if given. @@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore): the form `{"event_id": "..."}`. """ - where_clause = ["relates_to_id = ?"] - where_args: List[Union[str, int]] = [event_id] + where_clause = ["relates_to_id = ?", "room_id = ?"] + where_args: List[Union[str, int]] = [event_id, room_id] if relation_type is not None: where_clause.append("relation_type = ?") @@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_aggregation_groups_for_event( self, event_id: str, + room_id: str, event_type: Optional[str] = None, limit: int = 5, direction: str = "b", @@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. direction: Whether to fetch the highest count first (`"b"`) or @@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore): `type`, `key` and `count` fields. """ - where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION] + where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] + where_args: List[Union[str, int]] = [ + event_id, + room_id, + RelationTypes.ANNOTATION, + ] if event_type: where_clause.append("type = ?") @@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore): ) @cached() - async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: + async def get_applicable_edit( + self, event_id: str, room_id: str + ) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. @@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: The original event ID + room_id: The original event's room ID Returns: The most recent edit, if any. @@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore): WHERE relates_to_id = ? AND relation_type = ? + AND edit.room_id = ? AND edit.type = 'm.room.message' ORDER by edit.origin_server_ts DESC, edit.event_id DESC LIMIT 1 """ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: - txn.execute(sql, (event_id, RelationTypes.REPLACE)) + txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id)) row = txn.fetchone() if row: return row[0] @@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore): @cached() async def get_thread_summary( - self, event_id: str + self, event_id: str, room_id: str ) -> Tuple[int, Optional[EventBase]]: """Get the number of threaded replies, the senders of those replies, and the latest reply (if any) for the given event. Args: - event_id: The original event ID + event_id: Summarize the thread related to this event ID. + room_id: The room the event belongs to. Returns: The number of items in the thread and the most recent response, if any. @@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore): INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT 1 """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) row = txn.fetchone() if row is None: return 0, None @@ -376,14 +390,16 @@ class RelationsWorkerStore(SQLBaseStore): latest_event_id = row[0] sql = """ - SELECT COALESCE(COUNT(event_id), 0) + SELECT COUNT(event_id) FROM event_relations + INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) - count = txn.fetchone()[0] # type: ignore[index] + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) + count = cast(Tuple[int], txn.fetchone())[0] return count, latest_event_id diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7d694d852d..c0e837854a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -13,20 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging from abc import abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) + +import attr from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction -from synapse.storage.databases.main.search import SearchStore +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import IdGenerator from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -38,9 +54,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -RatelimitOverride = collections.namedtuple( - "RatelimitOverride", ("messages_per_second", "burst_count") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RatelimitOverride: + messages_per_second: int + burst_count: int class RoomSortOrder(Enum): @@ -71,8 +88,13 @@ class RoomSortOrder(Enum): STATE_EVENTS = "state_events" -class RoomWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): +class RoomWorkerStore(CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.config = hs.config @@ -83,7 +105,7 @@ class RoomWorkerStore(SQLBaseStore): room_creator_user_id: str, is_public: bool, room_version: RoomVersion, - ): + ) -> None: """Stores a room. Args: @@ -111,7 +133,7 @@ class RoomWorkerStore(SQLBaseStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def get_room(self, room_id: str) -> dict: + async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]: """Retrieve a room. Args: @@ -136,7 +158,9 @@ class RoomWorkerStore(SQLBaseStore): A dict containing the room information, or None if the room is unknown. """ - def get_room_with_stats_txn(txn, room_id): + def get_room_with_stats_txn( + txn: LoggingTransaction, room_id: str + ) -> Optional[Dict[str, Any]]: sql = """ SELECT room_id, state.name, state.canonical_alias, curr.joined_members, curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, @@ -185,7 +209,7 @@ class RoomWorkerStore(SQLBaseStore): ignore_non_federatable: If true filters out non-federatable rooms """ - def _count_public_rooms_txn(txn): + def _count_public_rooms_txn(txn: LoggingTransaction) -> int: query_args = [] if network_tuple: @@ -195,6 +219,7 @@ class RoomWorkerStore(SQLBaseStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ @@ -208,7 +233,7 @@ class RoomWorkerStore(SQLBaseStore): sql = """ SELECT - COALESCE(COUNT(*), 0) + COUNT(*) FROM ( %(published_sql)s ) published @@ -226,7 +251,7 @@ class RoomWorkerStore(SQLBaseStore): } txn.execute(sql, query_args) - return txn.fetchone()[0] + return cast(Tuple[int], txn.fetchone())[0] return await self.db_pool.runInteraction( "count_public_rooms", _count_public_rooms_txn @@ -235,11 +260,11 @@ class RoomWorkerStore(SQLBaseStore): async def get_room_count(self) -> int: """Retrieve the total number of rooms.""" - def f(txn): + def f(txn: LoggingTransaction) -> int: sql = "SELECT count(*) FROM rooms" txn.execute(sql) - row = txn.fetchone() - return row[0] or 0 + row = cast(Tuple[int], txn.fetchone()) + return row[0] return await self.db_pool.runInteraction("get_rooms", f) @@ -251,7 +276,7 @@ class RoomWorkerStore(SQLBaseStore): bounds: Optional[Tuple[int, str]], forwards: bool, ignore_non_federatable: bool = False, - ): + ) -> List[Dict[str, Any]]: """Gets the largest public rooms (where largest is in terms of joined members, as tracked in the statistics table). @@ -272,7 +297,7 @@ class RoomWorkerStore(SQLBaseStore): """ where_clauses = [] - query_args = [] + query_args: List[Union[str, int]] = [] if network_tuple: if network_tuple.appservice_id: @@ -281,6 +306,7 @@ class RoomWorkerStore(SQLBaseStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ @@ -372,7 +398,9 @@ class RoomWorkerStore(SQLBaseStore): LIMIT ? """ - def _get_largest_public_rooms_txn(txn): + def _get_largest_public_rooms_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Any]]: txn.execute(sql, query_args) results = self.db_pool.cursor_to_dict(txn) @@ -435,7 +463,7 @@ class RoomWorkerStore(SQLBaseStore): """ # Filter room names by a string where_statement = "" - search_pattern = [] + search_pattern: List[object] = [] if search_term: where_statement = """ WHERE LOWER(state.name) LIKE ? @@ -543,7 +571,9 @@ class RoomWorkerStore(SQLBaseStore): where_statement, ) - def _get_rooms_paginate_txn(txn): + def _get_rooms_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: # Add the search term into the WHERE clause # and execute the data query txn.execute(info_sql, search_pattern + [limit, start]) @@ -575,7 +605,7 @@ class RoomWorkerStore(SQLBaseStore): # Add the search term into the WHERE clause if present txn.execute(count_sql, search_pattern) - room_count = txn.fetchone() + room_count = cast(Tuple[int], txn.fetchone()) return rooms, room_count[0] return await self.db_pool.runInteraction( @@ -620,7 +650,7 @@ class RoomWorkerStore(SQLBaseStore): burst_count: How many actions that can be performed before being limited. """ - def set_ratelimit_txn(txn): + def set_ratelimit_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_upsert_txn( txn, table="ratelimit_override", @@ -643,7 +673,7 @@ class RoomWorkerStore(SQLBaseStore): user_id: user ID of the user """ - def delete_ratelimit_txn(txn): + def delete_ratelimit_txn(txn: LoggingTransaction) -> None: row = self.db_pool.simple_select_one_txn( txn, table="ratelimit_override", @@ -667,7 +697,7 @@ class RoomWorkerStore(SQLBaseStore): await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn) @cached() - async def get_retention_policy_for_room(self, room_id): + async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]: """Get the retention policy for a given room. If no retention policy has been found for this room, returns a policy defined @@ -676,13 +706,15 @@ class RoomWorkerStore(SQLBaseStore): configuration). Args: - room_id (str): The ID of the room to get the retention policy of. + room_id: The ID of the room to get the retention policy of. Returns: - dict[int, int]: "min_lifetime" and "max_lifetime" for this room. + A dict containing "min_lifetime" and "max_lifetime" for this room. """ - def get_retention_policy_for_room_txn(txn): + def get_retention_policy_for_room_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Optional[int]]]: txn.execute( """ SELECT min_lifetime, max_lifetime FROM room_retention @@ -707,19 +739,23 @@ class RoomWorkerStore(SQLBaseStore): "max_lifetime": self.config.retention.retention_default_max_lifetime, } - row = ret[0] + min_lifetime = ret[0]["min_lifetime"] + max_lifetime = ret[0]["max_lifetime"] # If one of the room's policy's attributes isn't defined, use the matching # attribute from the default policy. # The default values will be None if no default policy has been defined, or if one # of the attributes is missing from the default policy. - if row["min_lifetime"] is None: - row["min_lifetime"] = self.config.retention.retention_default_min_lifetime + if min_lifetime is None: + min_lifetime = self.config.retention.retention_default_min_lifetime - if row["max_lifetime"] is None: - row["max_lifetime"] = self.config.retention.retention_default_max_lifetime + if max_lifetime is None: + max_lifetime = self.config.retention.retention_default_max_lifetime - return row + return { + "min_lifetime": min_lifetime, + "max_lifetime": max_lifetime, + } async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room @@ -731,7 +767,9 @@ class RoomWorkerStore(SQLBaseStore): The local and remote media as a lists of the media IDs. """ - def _get_media_mxcs_in_room_txn(txn): + def _get_media_mxcs_in_room_txn( + txn: LoggingTransaction, + ) -> Tuple[List[str], List[str]]: local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) local_media_mxcs = [] remote_media_mxcs = [] @@ -757,7 +795,7 @@ class RoomWorkerStore(SQLBaseStore): logger.info("Quarantining media in room: %s", room_id) - def _quarantine_media_in_room_txn(txn): + def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int: local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) return self._quarantine_media_txn( txn, local_mxcs, remote_mxcs, quarantined_by @@ -767,13 +805,11 @@ class RoomWorkerStore(SQLBaseStore): "quarantine_media_in_room", _quarantine_media_in_room_txn ) - def _get_media_mxcs_in_room_txn(self, txn, room_id): + def _get_media_mxcs_in_room_txn( + self, txn: LoggingTransaction, room_id: str + ) -> Tuple[List[str], List[Tuple[str, str]]]: """Retrieves all the local and remote media MXC URIs in a given room - Args: - txn (cursor) - room_id (str) - Returns: The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. @@ -841,7 +877,7 @@ class RoomWorkerStore(SQLBaseStore): logger.info("Quarantining media: %s/%s", server_name, media_id) is_local = server_name == self.config.server.server_name - def _quarantine_media_by_id_txn(txn): + def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int: local_mxcs = [media_id] if is_local else [] remote_mxcs = [(server_name, media_id)] if not is_local else [] @@ -863,7 +899,7 @@ class RoomWorkerStore(SQLBaseStore): quarantined_by: The ID of the user who made the quarantine request """ - def _quarantine_media_by_user_txn(txn): + def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int: local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) @@ -871,7 +907,9 @@ class RoomWorkerStore(SQLBaseStore): "quarantine_media_by_user", _quarantine_media_by_user_txn ) - def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True): + def _get_media_ids_by_user_txn( + self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True + ) -> List[str]: """Retrieves local media IDs by a given user Args: @@ -900,7 +938,7 @@ class RoomWorkerStore(SQLBaseStore): def _quarantine_media_txn( self, - txn, + txn: LoggingTransaction, local_mxcs: List[str], remote_mxcs: List[Tuple[str, str]], quarantined_by: Optional[str], @@ -928,12 +966,15 @@ class RoomWorkerStore(SQLBaseStore): # set quarantine if quarantined_by is not None: sql += "AND safe_from_quarantine = ?" - rows = [(quarantined_by, media_id, False) for media_id in local_mxcs] + txn.executemany( + sql, [(quarantined_by, media_id, False) for media_id in local_mxcs] + ) # remove from quarantine else: - rows = [(quarantined_by, media_id) for media_id in local_mxcs] + txn.executemany( + sql, [(quarantined_by, media_id) for media_id in local_mxcs] + ) - txn.executemany(sql, rows) # Note that a rowcount of -1 can be used to indicate no rows were affected. total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 @@ -951,7 +992,7 @@ class RoomWorkerStore(SQLBaseStore): async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False - ) -> Dict[str, dict]: + ) -> Dict[str, Dict[str, Optional[int]]]: """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. @@ -971,7 +1012,9 @@ class RoomWorkerStore(SQLBaseStore): "min_lifetime" (int|None), and "max_lifetime" (int|None). """ - def get_rooms_for_retention_period_in_range_txn(txn): + def get_rooms_for_retention_period_in_range_txn( + txn: LoggingTransaction, + ) -> Dict[str, Dict[str, Optional[int]]]: range_conditions = [] args = [] @@ -1050,11 +1093,14 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( class RoomBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) - self.config = hs.config - self.db_pool.updates.register_background_update_handler( "insert_room_retention", self._background_insert_retention, @@ -1085,7 +1131,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_populate_rooms_creator_column, ) - async def _background_insert_retention(self, progress, batch_size): + async def _background_insert_retention( + self, progress: JsonDict, batch_size: int + ) -> int: """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. NULLs the property's columns if missing from the retention event in the room's @@ -1095,7 +1143,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): last_room = progress.get("room_id", "") - def _background_insert_retention_txn(txn): + def _background_insert_retention_txn(txn: LoggingTransaction) -> bool: txn.execute( """ SELECT state.room_id, state.event_id, events.json @@ -1154,15 +1202,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size async def _background_add_rooms_room_version_column( - self, progress: dict, batch_size: int - ): + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to go and add room version information to `rooms` table from `current_state_events` table. """ last_room_id = progress.get("room_id", "") - def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): + def _background_add_rooms_room_version_column_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT room_id, json FROM current_state_events INNER JOIN event_json USING (room_id, event_id) @@ -1223,7 +1273,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size async def _remove_tombstoned_rooms_from_directory( - self, progress, batch_size + self, progress: JsonDict, batch_size: int ) -> int: """Removes any rooms with tombstone events from the room directory @@ -1233,7 +1283,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): last_room = progress.get("room_id", "") - def _get_rooms(txn): + def _get_rooms(txn: LoggingTransaction) -> List[str]: txn.execute( """ SELECT room_id @@ -1271,7 +1321,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return len(rooms) @abstractmethod - def set_room_is_public(self, room_id, is_public): + def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]: # this will need to be implemented if a background update is performed with # existing (tombstoned, public) rooms in the database. # @@ -1318,7 +1368,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): 32-bit integer field. """ - def process(txn: Cursor) -> int: + def process(txn: LoggingTransaction) -> int: last_room = progress.get("last_room", "") txn.execute( """ @@ -1375,15 +1425,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return 0 async def _background_populate_rooms_creator_column( - self, progress: dict, batch_size: int - ): + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to go and add creator information to `rooms` table from `current_state_events` table. """ last_room_id = progress.get("room_id", "") - def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction): + def _background_populate_rooms_creator_column_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT room_id, json FROM event_json INNER JOIN rooms AS room USING (room_id) @@ -1434,15 +1486,20 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size -class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): +class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) - self.config = hs.config + self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") async def upsert_room_on_join( self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] - ): + ) -> None: """Ensure that the room is stored in the table Called when we join a room over federation, and overwrites any room version @@ -1488,7 +1545,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion - ): + ) -> None: """ When we receive an invite or any other event over federation that may relate to a room we are not in, store the version of the room if we don't already know the room version. @@ -1528,8 +1585,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.hs.get_notifier().on_new_replication_data() async def set_room_is_public_appservice( - self, room_id, appservice_id, network_id, is_public - ): + self, room_id: str, appservice_id: str, network_id: str, is_public: bool + ) -> None: """Edit the appservice/network specific public room list. Each appservice can have a number of published room lists associated @@ -1538,11 +1595,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): network. Args: - room_id (str) - appservice_id (str) - network_id (str) - is_public (bool): Whether to publish or unpublish the room from the - list. + room_id + appservice_id + network_id + is_public: Whether to publish or unpublish the room from the list. """ if is_public: @@ -1607,7 +1663,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): event_report: json list of information from event report """ - def _get_event_report_txn(txn, report_id): + def _get_event_report_txn( + txn: LoggingTransaction, report_id: int + ) -> Optional[Dict[str, Any]]: sql = """ SELECT @@ -1679,9 +1737,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): count: total number of event reports matching the filter criteria """ - def _get_event_reports_paginate_txn(txn): + def _get_event_reports_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: filters = [] - args = [] + args: List[object] = [] if user_id: filters.append("er.user_id LIKE ?") @@ -1705,7 +1765,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): where_clause ) txn.execute(sql, args) - count = txn.fetchone()[0] + count = cast(Tuple[int], txn.fetchone())[0] sql = """ SELECT diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a6..cda80d6511 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -37,7 +37,7 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( @@ -64,7 +64,12 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Used by `_get_joined_hosts` to ensure only one thing mutates the cache @@ -985,7 +990,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile @@ -1135,7 +1145,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) async def forget(self, user_id: str, room_id: str) -> None: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 642560a70d..3cbaca21b5 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py
@@ -14,13 +14,18 @@ import logging import re -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set +import attr + from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -29,10 +34,15 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -SearchEntry = namedtuple( - "SearchEntry", - ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], -) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SearchEntry: + key: str + value: str + event_id: str + room_id: str + stream_ordering: Optional[int] + origin_server_ts: int def _clean_value_for_search(value: str) -> str: @@ -105,7 +115,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if not hs.config.server.enable_search: @@ -358,7 +373,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) async def search_msgs(self, room_ids, search_term, keys): diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index fa2c3b1feb..2fb3e65192 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,6 @@ # limitations under the License. import collections.abc import logging -from collections import namedtuple from typing import TYPE_CHECKING, Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership @@ -22,7 +21,11 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter @@ -39,24 +42,16 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): - """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching - """ - - __slots__ = [] - - def __len__(self): - return len(self.delta_ids) if self.delta_ids else 0 - - # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: @@ -182,11 +177,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): NotFoundError if the room is unknown """ state_ids = await self.get_current_state_ids(room_id) + + if not state_ids: + raise NotFoundError(f"Current state for room {room_id} is empty") + create_id = state_ids.get((EventTypes.Create, "")) # If we can't find the create event, assume we've hit a dead end if not create_id: - raise NotFoundError("Unknown room %s" % (room_id,)) + raise NotFoundError(f"No create event in current state for room {room_id}") # Retrieve the room's create event and return create_event = await self.get_event(create_id) @@ -349,7 +348,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -536,5 +540,10 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 7f3624b128..188afec332 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py
@@ -56,7 +56,9 @@ class StateDeltasStore(SQLBaseStore): prev_stream_id = int(prev_stream_id) # check we're not going backwards - assert prev_stream_id <= max_stream_id + assert ( + prev_stream_id <= max_stream_id + ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}" if not self._curr_state_delta_stream_cache.has_any_entity_changed( prev_stream_id diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5d7b59d861..427ae1f649 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@ import logging from enum import Enum from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast from typing_extensions import Counter @@ -24,7 +24,11 @@ from twisted.internet.defer import DeferredLock from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import StoreError -from synapse.storage.database import DatabasePool +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -96,7 +100,12 @@ class UserSortOrder(Enum): class StatsStore(StateDeltasStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -117,7 +126,9 @@ class StatsStore(StateDeltasStore): self.db_pool.updates.register_noop_background_update("populate_stats_cleanup") self.db_pool.updates.register_noop_background_update("populate_stats_prepare") - async def _populate_stats_process_users(self, progress, batch_size): + async def _populate_stats_process_users( + self, progress: JsonDict, batch_size: int + ) -> int: """ This is a background update which regenerates statistics for users. """ @@ -129,7 +140,7 @@ class StatsStore(StateDeltasStore): last_user_id = progress.get("last_user_id", "") - def _get_next_batch(txn): + def _get_next_batch(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT name FROM users WHERE name > ? @@ -163,7 +174,9 @@ class StatsStore(StateDeltasStore): return len(users_to_work_on) - async def _populate_stats_process_rooms(self, progress, batch_size): + async def _populate_stats_process_rooms( + self, progress: JsonDict, batch_size: int + ) -> int: """This is a background update which regenerates statistics for rooms.""" if not self.stats_enabled: await self.db_pool.updates._end_background_update( @@ -173,7 +186,7 @@ class StatsStore(StateDeltasStore): last_room_id = progress.get("last_room_id", "") - def _get_next_batch(txn): + def _get_next_batch(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT room_id FROM current_state_events WHERE room_id > ? @@ -302,7 +315,7 @@ class StatsStore(StateDeltasStore): stream_id: Current position. """ - def _bulk_update_stats_delta_txn(txn): + def _bulk_update_stats_delta_txn(txn: LoggingTransaction) -> None: for stats_type, stats_updates in updates.items(): for stats_id, fields in stats_updates.items(): logger.debug( @@ -334,7 +347,7 @@ class StatsStore(StateDeltasStore): stats_type: str, stats_id: str, fields: Dict[str, int], - complete_with_stream_id: Optional[int], + complete_with_stream_id: int, absolute_field_overrides: Optional[Dict[str, int]] = None, ) -> None: """ @@ -367,14 +380,14 @@ class StatsStore(StateDeltasStore): def _update_stats_delta_txn( self, - txn, - ts, - stats_type, - stats_id, - fields, - complete_with_stream_id, - absolute_field_overrides=None, - ): + txn: LoggingTransaction, + ts: int, + stats_type: str, + stats_id: str, + fields: Dict[str, int], + complete_with_stream_id: int, + absolute_field_overrides: Optional[Dict[str, int]] = None, + ) -> None: if absolute_field_overrides is None: absolute_field_overrides = {} @@ -417,20 +430,23 @@ class StatsStore(StateDeltasStore): ) def _upsert_with_additive_relatives_txn( - self, txn, table, keyvalues, absolutes, additive_relatives - ): + self, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + absolutes: Dict[str, Any], + additive_relatives: Dict[str, int], + ) -> None: """Used to update values in the stats tables. This is basically a slightly convoluted upsert that *adds* to any existing rows. Args: - txn - table (str): Table name - keyvalues (dict[str, any]): Row-identifying key values - absolutes (dict[str, any]): Absolute (set) fields - additive_relatives (dict[str, int]): Fields that will be added onto - if existing row present. + table: Table name + keyvalues: Row-identifying key values + absolutes: Absolute (set) fields + additive_relatives: Fields that will be added onto if existing row present. """ if self.database_engine.can_native_upsert: absolute_updates = [ @@ -486,20 +502,17 @@ class StatsStore(StateDeltasStore): current_row.update(absolutes) self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row) - async def _calculate_and_set_initial_state_for_room( - self, room_id: str - ) -> Tuple[dict, dict, int]: + async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None: """Calculate and insert an entry into room_stats_current. Args: room_id: The room ID under calculation. - - Returns: - A tuple of room state, membership counts and stream position. """ - def _fetch_current_state_stats(txn): - pos = self.get_room_max_stream_ordering() + def _fetch_current_state_stats( + txn: LoggingTransaction, + ) -> Tuple[List[str], Dict[str, int], int, List[str], int]: + pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined] rows = self.db_pool.simple_select_many_txn( txn, @@ -519,7 +532,7 @@ class StatsStore(StateDeltasStore): retcols=["event_id"], ) - event_ids = [row["event_id"] for row in rows] + event_ids = cast(List[str], [row["event_id"] for row in rows]) txn.execute( """ @@ -533,15 +546,15 @@ class StatsStore(StateDeltasStore): txn.execute( """ - SELECT COALESCE(count(*), 0) FROM current_state_events + SELECT COUNT(*) FROM current_state_events WHERE room_id = ? """, (room_id,), ) - (current_state_events_count,) = txn.fetchone() + current_state_events_count = cast(Tuple[int], txn.fetchone())[0] - users_in_room = self.get_users_in_room_txn(txn, room_id) + users_in_room = self.get_users_in_room_txn(txn, room_id) # type: ignore[attr-defined] return ( event_ids, @@ -561,7 +574,7 @@ class StatsStore(StateDeltasStore): "get_initial_state_for_room", _fetch_current_state_stats ) - state_event_map = await self.get_events(event_ids, get_prev_content=False) + state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined] room_state = { "join_rules": None, @@ -617,8 +630,10 @@ class StatsStore(StateDeltasStore): }, ) - async def _calculate_and_set_initial_state_for_user(self, user_id): - def _calculate_and_set_initial_state_for_user_txn(txn): + async def _calculate_and_set_initial_state_for_user(self, user_id: str) -> None: + def _calculate_and_set_initial_state_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[int, int]: pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) txn.execute( @@ -629,7 +644,7 @@ class StatsStore(StateDeltasStore): """, (user_id,), ) - (count,) = txn.fetchone() + count = cast(Tuple[int], txn.fetchone())[0] return count, pos joined_rooms, pos = await self.db_pool.runInteraction( @@ -673,7 +688,9 @@ class StatsStore(StateDeltasStore): users that exist given this query """ - def get_users_media_usage_paginate_txn(txn): + def get_users_media_usage_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: filters = [] args = [self.hs.config.server.server_name] @@ -728,7 +745,7 @@ class StatsStore(StateDeltasStore): sql_base=sql_base, ) txn.execute(sql, args) - count = txn.fetchone()[0] + count = cast(Tuple[int], txn.fetchone())[0] sql = """ SELECT diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab55259..319464b1fa 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -34,11 +34,11 @@ what sort order was used: - topological tokems: "t%d-%d", where the integers map to the topological and stream ordering columns respectively. """ -import abc + import logging -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple +import attr from frozendict import frozendict from twisted.internet import defer @@ -49,6 +49,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_in_list_sql_clause, ) @@ -73,9 +74,11 @@ _TOPOLOGICAL_TOKEN = "topological" # Used as return values for pagination APIs -_EventDictReturn = namedtuple( - "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventDictReturn: + event_id: str + topological_ordering: Optional[int] + stream_ordering: int def generate_pagination_where_clause( @@ -333,13 +336,13 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: return " AND ".join(clauses), args -class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): - """This is an abstract base class where subclasses must implement - `get_room_max_stream_ordering` and `get_room_min_stream_ordering` - which can be called in the initializer. - """ - - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): +class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() @@ -371,13 +374,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): self._stream_order_on_start = self.get_room_max_stream_ordering() - @abc.abstractmethod def get_room_max_stream_ordering(self) -> int: - raise NotImplementedError() + """Get the stream_ordering of regular events that we have committed up to + + Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + """ + return self._stream_id_gen.get_current_token() - @abc.abstractmethod def get_room_min_stream_ordering(self) -> int: - raise NotImplementedError() + """Get the stream_ordering of backfilled events that we have committed up to + + Backfilled events use *negative* stream orderings, so this returns the + minimum negative stream id such that all stream ids greater than or + equal to it have been successfully persisted. + """ + return self._backfill_id_gen.get_current_token() def get_room_max_token(self) -> RoomStreamToken: """Get a `RoomStreamToken` that marks the current maximum persisted @@ -819,7 +831,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): for event, row in zip(events, rows): stream = row.stream_ordering if topo_order and row.topological_ordering: - topo = row.topological_ordering + topo: Optional[int] = row.topological_ordering else: topo = None internal = event.internal_metadata @@ -1343,11 +1355,3 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): retcol="instance_name", desc="get_name_from_instance_id", ) - - -class StreamStore(StreamWorkerStore): - def get_room_max_stream_ordering(self) -> int: - return self._stream_id_gen.get_current_token() - - def get_room_min_stream_ordering(self) -> int: - return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 8f510de53d..c8e508a910 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py
@@ -15,11 +15,13 @@ # limitations under the License. import logging -from typing import Dict, List, Tuple, cast +from typing import Any, Dict, Iterable, List, Tuple, cast +from synapse.replication.tcp.streams import TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -204,6 +206,7 @@ class TagsWorkerStore(AccountDataWorkerStore): The next account data ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -230,6 +233,7 @@ class TagsWorkerStore(AccountDataWorkerStore): The next account data ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: sql = ( @@ -258,6 +262,7 @@ class TagsWorkerStore(AccountDataWorkerStore): next_id: The the revision to advance to. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) txn.call_after( self._account_data_stream_cache.entity_has_changed, user_id, next_id @@ -287,6 +292,21 @@ class TagsWorkerStore(AccountDataWorkerStore): # than the id that the client has. pass + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + + super().process_replication_rows(stream_name, instance_name, token, rows) + class TagsStore(TagsWorkerStore): pass diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..6c299cafa5 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -13,16 +13,19 @@ # limitations under the License. import logging -from collections import namedtuple from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast import attr from canonicaljson import encode_canonical_json from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -35,16 +38,6 @@ db_binary_type = memoryview logger = logging.getLogger(__name__) -_TransactionRow = namedtuple( - "_TransactionRow", - ("id", "transaction_id", "destination", "ts", "response_code", "response_json"), -) - -_UpdateTransactionRow = namedtuple( - "_TransactionRow", ("response_code", "response_json") -) - - class DestinationSortOrder(Enum): """Enum to define the sorting method used when returning destinations.""" @@ -71,7 +64,12 @@ class DestinationRetryTimings: class TransactionWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -82,7 +80,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): now = self._clock.time_msec() month_ago = now - 30 * 24 * 60 * 60 * 1000 - def _cleanup_transactions_txn(txn): + def _cleanup_transactions_txn(txn: LoggingTransaction) -> None: txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) await self.db_pool.runInteraction( @@ -112,7 +110,9 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): origin, ) - def _get_received_txn_response(self, txn, transaction_id, origin): + def _get_received_txn_response( + self, txn: LoggingTransaction, transaction_id: str, origin: str + ) -> Optional[Tuple[int, JsonDict]]: result = self.db_pool.simple_select_one_txn( txn, table="received_transactions", @@ -187,7 +187,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): return result def _get_destination_retry_timings( - self, txn, destination: str + self, txn: LoggingTransaction, destination: str ) -> Optional[DestinationRetryTimings]: result = self.db_pool.simple_select_one_txn( txn, @@ -222,7 +222,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): """ if self.database_engine.can_native_upsert: - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings_native, destination, @@ -232,7 +232,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): db_autocommit=True, # Safe as its a single upsert ) else: - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings_emulated, destination, @@ -242,8 +242,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): ) def _set_destination_retry_timings_native( - self, txn, destination, failure_ts, retry_last_ts, retry_interval - ): + self, + txn: LoggingTransaction, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: assert self.database_engine.can_native_upsert # Upsert retry time interval if retry_interval is zero (i.e. we're @@ -273,8 +278,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): ) def _set_destination_retry_timings_emulated( - self, txn, destination, failure_ts, retry_last_ts, retry_interval - ): + self, + txn: LoggingTransaction, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: self.database_engine.lock_table(txn, "destinations") # We need to be careful here as the data may have changed from under us @@ -384,7 +394,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): last_successful_stream_ordering: the stream_ordering of the most recent successfully-sent PDU """ - return await self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( "destinations", keyvalues={"destination": destination}, values={"last_successful_stream_ordering": last_successful_stream_ordering}, @@ -525,7 +535,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): else: order = "ASC" - args = [] + args: List[object] = [] where_statement = "" if destination: args.extend(["%" + destination.lower() + "%"]) @@ -534,7 +544,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): sql_base = f"FROM destinations {where_statement} " sql = f"SELECT COUNT(*) as total_destinations {sql_base}" txn.execute(sql, args) - count = txn.fetchone()[0] + count = cast(Tuple[int], txn.fetchone())[0] sql = f""" SELECT destination, retry_last_ts, retry_interval, failure_ts, diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 340ca9e47d..a1a1a6a14a 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py
@@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast import attr @@ -225,11 +225,14 @@ class UIAuthWorkerStore(SQLBaseStore): self, txn: LoggingTransaction, session_id: str, key: str, value: Any ): # Get the current value. - result: Dict[str, Any] = self.db_pool.simple_select_one_txn( # type: ignore - txn, - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - retcols=("serverdict",), + result = cast( + Dict[str, Any], + self.db_pool.simple_select_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + ), ) # Update it and add it back to the database. diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e98a45b6af..0f9b8575d3 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
@@ -32,11 +32,14 @@ if TYPE_CHECKING: from synapse.server import HomeServer from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.storage.types import Connection from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -53,7 +56,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) @@ -592,7 +595,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ) -> None: super().__init__(database, db_conn, hs) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 50d08094d5..2a3d47185a 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 66 # remember to update the list below when updating +SCHEMA_VERSION = 67 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -50,6 +50,9 @@ Changes in SCHEMA_VERSION = 65: Changes in SCHEMA_VERSION = 66: - Queries on state_key columns are now disambiguated (ie, the codebase can handle the `events` table having a `state_key` column). + +Changes in SCHEMA_VERSION = 67: + - state_events.prev_state is no longer written to. """ diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 4ff3013908..b8112e1c05 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py
@@ -74,8 +74,6 @@ class IdGenerator: def _load_current_id( db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1 ) -> int: - # debug logging for https://github.com/matrix-org/synapse/issues/7968 - logger.info("initialising stream generator for %s(%s)", table, column) cur = db_conn.cursor(txn_name="_load_current_id") if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) @@ -86,7 +84,9 @@ def _load_current_id( (val,) = result cur.close() current_id = int(val) if val else step - return (max if step > 0 else min)(current_id, step) + res = (max if step > 0 else min)(current_id, step) + logger.info("Initialising stream generator for %s(%s): %i", table, column, res) + return res class AbstractStreamIdTracker(metaclass=abc.ABCMeta):